mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 02:32:37 +03:00
* Work on implementing a trainable cache for the parser. So far, doesn't improve efficiency
This commit is contained in:
parent
033d6c9ac2
commit
53b8bc1f3c
|
@ -31,6 +31,7 @@ cdef class DecisionMemory:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef PreshCounter _counts
|
cdef PreshCounter _counts
|
||||||
cdef PreshCounter _class_counts
|
cdef PreshCounter _class_counts
|
||||||
|
cdef PreshMap memos
|
||||||
cdef list class_names
|
cdef list class_names
|
||||||
|
|
||||||
cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1
|
cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Create a term-document matrix"""
|
"""Create a term-document matrix"""
|
||||||
cimport cython
|
cimport cython
|
||||||
|
from libc.stdint cimport int64_t
|
||||||
from libc.string cimport memmove
|
from libc.string cimport memmove
|
||||||
|
|
||||||
from cymem.cymem cimport Address
|
from cymem.cymem cimport Address
|
||||||
|
@ -42,6 +42,20 @@ cdef class DecisionMemory:
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self._counts = PreshCounter()
|
self._counts = PreshCounter()
|
||||||
self._class_counts = PreshCounter()
|
self._class_counts = PreshCounter()
|
||||||
|
self.memos = PreshMap()
|
||||||
|
|
||||||
|
def load(self, loc, thresh=0):
|
||||||
|
cdef:
|
||||||
|
count_t freq
|
||||||
|
hash_t key
|
||||||
|
int clas
|
||||||
|
for line in open(loc):
|
||||||
|
freq, key, clas = [int(p) for p in line.split()]
|
||||||
|
if thresh == 0 or freq >= thresh:
|
||||||
|
self.memos.set(key, <void*>(clas+1))
|
||||||
|
|
||||||
|
def get(self, hash_t context_key):
|
||||||
|
return <int64_t>self.memos.get(context_key) - 1
|
||||||
|
|
||||||
def __getitem__(self, ids):
|
def __getitem__(self, ids):
|
||||||
cdef id_t[2] context
|
cdef id_t[2] context
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
from thinc.features cimport Extractor
|
from thinc.features cimport Extractor
|
||||||
from thinc.learner cimport LinearModel
|
from thinc.learner cimport LinearModel
|
||||||
|
|
||||||
from .arc_eager cimport TransitionSystem
|
from .arc_eager cimport TransitionSystem
|
||||||
|
|
||||||
from ..tokens cimport Tokens, TokenC
|
from ..tokens cimport Tokens, TokenC
|
||||||
|
from ._state cimport State
|
||||||
|
from ..index cimport DecisionMemory
|
||||||
|
|
||||||
|
|
||||||
cdef class GreedyParser:
|
cdef class GreedyParser:
|
||||||
|
@ -11,5 +14,6 @@ cdef class GreedyParser:
|
||||||
cdef Extractor extractor
|
cdef Extractor extractor
|
||||||
cdef readonly LinearModel model
|
cdef readonly LinearModel model
|
||||||
cdef TransitionSystem moves
|
cdef TransitionSystem moves
|
||||||
|
cdef readonly DecisionMemory guess_cache
|
||||||
|
|
||||||
cpdef int parse(self, Tokens tokens) except -1
|
cpdef int parse(self, Tokens tokens) except -1
|
||||||
|
|
|
@ -4,6 +4,7 @@ MALT-style dependency parser
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
cimport cython
|
cimport cython
|
||||||
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
import random
|
import random
|
||||||
import os.path
|
import os.path
|
||||||
from os.path import join as pjoin
|
from os.path import join as pjoin
|
||||||
|
@ -11,6 +12,7 @@ import shutil
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from cymem.cymem cimport Pool, Address
|
from cymem.cymem cimport Pool, Address
|
||||||
|
from murmurhash.mrmr cimport hash64
|
||||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,7 +28,7 @@ from ..tokens cimport Tokens, TokenC
|
||||||
|
|
||||||
from .arc_eager cimport TransitionSystem, Transition
|
from .arc_eager cimport TransitionSystem, Transition
|
||||||
|
|
||||||
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1
|
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
|
||||||
|
|
||||||
from . import _parse_features
|
from . import _parse_features
|
||||||
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
||||||
|
@ -63,24 +65,39 @@ cdef class GreedyParser:
|
||||||
self.extractor = Extractor(get_templates(self.cfg.features))
|
self.extractor = Extractor(get_templates(self.cfg.features))
|
||||||
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
||||||
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
|
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
|
||||||
|
# Classes for decision memory
|
||||||
|
classes = ['S', 'D']
|
||||||
|
classes += ['L-%s' % label for label in self.cfg.left_labels]
|
||||||
|
classes += ['R-%s' % label for label in self.cfg.right_labels]
|
||||||
|
self.guess_cache = DecisionMemory(classes)
|
||||||
if os.path.exists(pjoin(model_dir, 'model')):
|
if os.path.exists(pjoin(model_dir, 'model')):
|
||||||
self.model.load(pjoin(model_dir, 'model'))
|
self.model.load(pjoin(model_dir, 'model'))
|
||||||
|
if os.path.exists(pjoin(model_dir, 'guess_cache')):
|
||||||
|
self.guess_cache.load(pjoin(model_dir, 'guess_cache'))
|
||||||
|
|
||||||
cpdef int parse(self, Tokens tokens) except -1:
|
cpdef int parse(self, Tokens tokens) except -1:
|
||||||
cdef:
|
cdef:
|
||||||
const Feature* feats
|
const Feature* feats
|
||||||
const weight_t* scores
|
const weight_t* scores
|
||||||
Transition guess
|
Transition guess
|
||||||
|
uint64_t state_key
|
||||||
|
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||||
|
cdef int guess_clas
|
||||||
while not is_final(state):
|
while not is_final(state):
|
||||||
|
state_key = _approx_hash_state(state)
|
||||||
|
guess_clas = self.guess_cache.get(state_key)
|
||||||
|
if guess_clas == -1:
|
||||||
fill_context(context, state)
|
fill_context(context, state)
|
||||||
feats = self.extractor.get_feats(context, &n_feats)
|
feats = self.extractor.get_feats(context, &n_feats)
|
||||||
scores = self.model.get_scores(feats, n_feats)
|
scores = self.model.get_scores(feats, n_feats)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
|
self.guess_cache.inc(state_key, guess.clas, 1)
|
||||||
|
else:
|
||||||
|
guess = self.moves._moves[guess_clas]
|
||||||
self.moves.transition(state, &guess)
|
self.moves.transition(state, &guess)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@ -117,6 +134,17 @@ cdef class GreedyParser:
|
||||||
return n_corr
|
return n_corr
|
||||||
|
|
||||||
|
|
||||||
|
cdef uint64_t _approx_hash_state(const State* state) except 0:
|
||||||
|
cdef int[3] context
|
||||||
|
context[0] = get_s0(state).lex.sic
|
||||||
|
context[1] = get_n0(state).lex.sic
|
||||||
|
if get_n1(state):
|
||||||
|
context[2] = get_n1(state).pos
|
||||||
|
else:
|
||||||
|
context[2] = 0
|
||||||
|
return hash64(context, sizeof(int) * 3, 0)
|
||||||
|
|
||||||
|
|
||||||
cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats,
|
cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats,
|
||||||
int inc):
|
int inc):
|
||||||
if guess == best:
|
if guess == best:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user