mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Use new Example class
This commit is contained in:
parent
735f1af91f
commit
f4986d5d3c
|
@ -5,6 +5,7 @@ from cymem.cymem cimport Pool
|
|||
from thinc.learner cimport LinearModel
|
||||
from thinc.features cimport Extractor, Feature
|
||||
from thinc.typedefs cimport atom_t, feat_t, weight_t, class_t
|
||||
from thinc.api cimport ExampleC
|
||||
|
||||
from preshed.maps cimport PreshMapArray
|
||||
|
||||
|
|
|
@ -61,18 +61,14 @@ cdef class Model:
|
|||
self._model.load(self.model_loc, freq_thresh=0)
|
||||
|
||||
def predict(self, Example eg):
|
||||
self.set_scores(&eg.scores[0], &eg.atoms[0])
|
||||
eg.guess = arg_max_if_true(&eg.scores[0], &eg.is_valid[0],
|
||||
self.n_classes)
|
||||
self.set_scores(eg.c.scores, eg.c.atoms)
|
||||
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
|
||||
def train(self, Example eg):
|
||||
self.set_scores(&eg.scores[0], &eg.atoms[0])
|
||||
eg.guess = arg_max_if_true(&eg.scores[0],
|
||||
&eg.is_valid[0], self.n_classes)
|
||||
eg.best = arg_max_if_zero(&eg.scores[0], &eg.costs[0],
|
||||
self.n_classes)
|
||||
eg.cost = eg.costs[eg.guess]
|
||||
self.update(&eg.atoms[0], eg.guess, eg.best, eg.cost)
|
||||
self.predict(eg)
|
||||
eg.c.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
|
||||
eg.c.cost = eg.c.costs[eg.c.guess]
|
||||
self.update(eg.c.atoms, eg.c.guess, eg.c.best, eg.c.cost)
|
||||
|
||||
cdef const weight_t* score(self, atom_t* context) except NULL:
|
||||
cdef int n_feats
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from thinc.api cimport Example
|
||||
from thinc.api cimport Example, ExampleC
|
||||
from thinc.typedefs cimport weight_t
|
||||
|
||||
from ._ml cimport arg_max_if_true
|
||||
|
@ -33,20 +33,17 @@ cdef class TheanoModel(Model):
|
|||
cdef int i
|
||||
for i in range(self.n_classes):
|
||||
eg.scores[i] = theano_scores[i]
|
||||
eg.guess = arg_max_if_true(&eg.scores[0], <int*>eg.is_valid[0],
|
||||
self.n_classes)
|
||||
eg.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
|
||||
def train(self, Example eg):
|
||||
self.input_layer.fill(eg.embeddings, eg.atoms, use_avg=False)
|
||||
theano_scores, update, y = self.train_func(eg.embeddings, eg.costs, self.eta)
|
||||
self.input_layer.update(update, eg.atoms, self.t, self.eta, self.mu)
|
||||
for i in range(self.n_classes):
|
||||
eg.scores[i] = theano_scores[i]
|
||||
eg.guess = arg_max_if_true(&eg.scores[0], <int*>eg.is_valid[0],
|
||||
self.n_classes)
|
||||
eg.best = arg_max_if_zero(&eg.scores[0], <int*>eg.costs[0],
|
||||
self.n_classes)
|
||||
eg.cost = eg.costs[eg.guess]
|
||||
eg.c.scores[i] = theano_scores[i]
|
||||
eg.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
eg.best = arg_max_if_zero(eg.c.scores, eg.c.costs, self.n_classes)
|
||||
eg.cost = eg.c.costs[eg.guess]
|
||||
self.t += 1
|
||||
|
||||
def end_training(self):
|
||||
|
|
|
@ -71,14 +71,17 @@ cdef class Parser:
|
|||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats)
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||
self.model.n_feats, self.model.n_feats)
|
||||
while not stcls.is_final():
|
||||
eg.wipe()
|
||||
fill_context(&eg.atoms[0], stcls)
|
||||
self.moves.set_valid(<bint*>&eg.is_valid[0], stcls)
|
||||
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
|
||||
|
||||
self.moves.set_valid(<bint*>eg.c.is_valid, stcls)
|
||||
fill_context(eg.c.atoms, stcls)
|
||||
|
||||
self.model.predict(eg)
|
||||
|
||||
self.moves.c[eg.guess].do(stcls, self.moves.c[eg.guess].label)
|
||||
self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
|
||||
self.moves.finalize_state(stcls)
|
||||
tokens.set_parse(stcls._sent)
|
||||
|
||||
|
@ -86,20 +89,24 @@ cdef class Parser:
|
|||
self.moves.preprocess_gold(gold)
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE, self.model.n_feats)
|
||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||
self.model.n_feats, self.model.n_feats)
|
||||
cdef int cost = 0
|
||||
while not stcls.is_final():
|
||||
eg.wipe()
|
||||
fill_context(&eg.atoms[0], stcls)
|
||||
self.moves.set_costs(<bint*>&eg.is_valid[0], &eg.costs[0], stcls, gold)
|
||||
memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
|
||||
|
||||
self.moves.set_costs(<bint*>eg.c.is_valid, eg.c.costs, stcls, gold)
|
||||
|
||||
fill_context(eg.c.atoms, stcls)
|
||||
|
||||
self.model.train(eg)
|
||||
|
||||
self.moves.c[eg.guess].do(stcls, self.moves.c[eg.guess].label)
|
||||
cost += eg.cost
|
||||
self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
|
||||
cost += eg.c.cost
|
||||
return cost
|
||||
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
"""
|
||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||
|
|
Loading…
Reference in New Issue
Block a user