mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
* Work on implementing a trainable cache for the parser. So far, doesn't improve efficiency
This commit is contained in:
parent
033d6c9ac2
commit
53b8bc1f3c
|
@ -31,6 +31,7 @@ cdef class DecisionMemory:
|
|||
cdef Pool mem
|
||||
cdef PreshCounter _counts
|
||||
cdef PreshCounter _class_counts
|
||||
cdef PreshMap memos
|
||||
cdef list class_names
|
||||
|
||||
cdef int inc(self, hash_t context_key, hash_t clas, count_t inc) except -1
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Create a term-document matrix"""
|
||||
cimport cython
|
||||
|
||||
from libc.stdint cimport int64_t
|
||||
from libc.string cimport memmove
|
||||
|
||||
from cymem.cymem cimport Address
|
||||
|
@ -42,6 +42,20 @@ cdef class DecisionMemory:
|
|||
self.mem = Pool()
|
||||
self._counts = PreshCounter()
|
||||
self._class_counts = PreshCounter()
|
||||
self.memos = PreshMap()
|
||||
|
||||
def load(self, loc, thresh=0):
|
||||
cdef:
|
||||
count_t freq
|
||||
hash_t key
|
||||
int clas
|
||||
for line in open(loc):
|
||||
freq, key, clas = [int(p) for p in line.split()]
|
||||
if thresh == 0 or freq >= thresh:
|
||||
self.memos.set(key, <void*>(clas+1))
|
||||
|
||||
def get(self, hash_t context_key):
|
||||
return <int64_t>self.memos.get(context_key) - 1
|
||||
|
||||
def __getitem__(self, ids):
|
||||
cdef id_t[2] context
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from libc.stdint cimport uint32_t, uint64_t
|
||||
from thinc.features cimport Extractor
|
||||
from thinc.learner cimport LinearModel
|
||||
|
||||
from .arc_eager cimport TransitionSystem
|
||||
|
||||
from ..tokens cimport Tokens, TokenC
|
||||
from ._state cimport State
|
||||
from ..index cimport DecisionMemory
|
||||
|
||||
|
||||
cdef class GreedyParser:
|
||||
|
@ -11,5 +14,6 @@ cdef class GreedyParser:
|
|||
cdef Extractor extractor
|
||||
cdef readonly LinearModel model
|
||||
cdef TransitionSystem moves
|
||||
cdef readonly DecisionMemory guess_cache
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1
|
||||
|
|
|
@ -4,6 +4,7 @@ MALT-style dependency parser
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
cimport cython
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
import random
|
||||
import os.path
|
||||
from os.path import join as pjoin
|
||||
|
@ -11,6 +12,7 @@ import shutil
|
|||
import json
|
||||
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t
|
||||
|
||||
|
||||
|
@ -26,7 +28,7 @@ from ..tokens cimport Tokens, TokenC
|
|||
|
||||
from .arc_eager cimport TransitionSystem, Transition
|
||||
|
||||
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1
|
||||
from ._state cimport init_state, State, is_final, get_idx, get_s0, get_s1, get_n0, get_n1
|
||||
|
||||
from . import _parse_features
|
||||
from ._parse_features cimport fill_context, CONTEXT_SIZE
|
||||
|
@ -63,24 +65,39 @@ cdef class GreedyParser:
|
|||
self.extractor = Extractor(get_templates(self.cfg.features))
|
||||
self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels)
|
||||
self.model = LinearModel(self.moves.n_moves, self.extractor.n_templ)
|
||||
# Classes for decision memory
|
||||
classes = ['S', 'D']
|
||||
classes += ['L-%s' % label for label in self.cfg.left_labels]
|
||||
classes += ['R-%s' % label for label in self.cfg.right_labels]
|
||||
self.guess_cache = DecisionMemory(classes)
|
||||
if os.path.exists(pjoin(model_dir, 'model')):
|
||||
self.model.load(pjoin(model_dir, 'model'))
|
||||
if os.path.exists(pjoin(model_dir, 'guess_cache')):
|
||||
self.guess_cache.load(pjoin(model_dir, 'guess_cache'))
|
||||
|
||||
cpdef int parse(self, Tokens tokens) except -1:
|
||||
cdef:
|
||||
const Feature* feats
|
||||
const weight_t* scores
|
||||
Transition guess
|
||||
uint64_t state_key
|
||||
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
cdef State* state = init_state(mem, tokens.data, tokens.length)
|
||||
cdef int guess_clas
|
||||
while not is_final(state):
|
||||
fill_context(context, state)
|
||||
feats = self.extractor.get_feats(context, &n_feats)
|
||||
scores = self.model.get_scores(feats, n_feats)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
state_key = _approx_hash_state(state)
|
||||
guess_clas = self.guess_cache.get(state_key)
|
||||
if guess_clas == -1:
|
||||
fill_context(context, state)
|
||||
feats = self.extractor.get_feats(context, &n_feats)
|
||||
scores = self.model.get_scores(feats, n_feats)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
self.guess_cache.inc(state_key, guess.clas, 1)
|
||||
else:
|
||||
guess = self.moves._moves[guess_clas]
|
||||
self.moves.transition(state, &guess)
|
||||
return 0
|
||||
|
||||
|
@ -117,6 +134,17 @@ cdef class GreedyParser:
|
|||
return n_corr
|
||||
|
||||
|
||||
cdef uint64_t _approx_hash_state(const State* state) except 0:
|
||||
cdef int[3] context
|
||||
context[0] = get_s0(state).lex.sic
|
||||
context[1] = get_n0(state).lex.sic
|
||||
if get_n1(state):
|
||||
context[2] = get_n1(state).pos
|
||||
else:
|
||||
context[2] = 0
|
||||
return hash64(context, sizeof(int) * 3, 0)
|
||||
|
||||
|
||||
cdef dict _get_counts(int guess, int best, const Feature* feats, const int n_feats,
|
||||
int inc):
|
||||
if guess == best:
|
||||
|
|
Loading…
Reference in New Issue
Block a user