mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix beam_parser for new API
This commit is contained in:
parent
2f09b041d1
commit
a664aa8180
|
@ -118,26 +118,28 @@ cdef class BeamParser(Parser):
|
|||
for grad, hist in histories:
|
||||
assert not math.isnan(grad) and not math.isinf(grad)
|
||||
if abs(grad) >= min_grad:
|
||||
self._update_from_history(self.moves, tokens, hist, grad)
|
||||
self.model._update_from_history(self.moves, tokens, hist, grad)
|
||||
_cleanup(pred)
|
||||
_cleanup(gold)
|
||||
return pred.loss
|
||||
|
||||
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
|
||||
cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE,
|
||||
nr_feat=self.model.nr_feat, widths=self.model.widths)
|
||||
cdef ExampleC* eg = py_eg.c
|
||||
|
||||
cdef Pool mem = Pool()
|
||||
features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
|
||||
cdef ParserNeuralNet nn_model = None
|
||||
cdef ParserPerceptron ap_model = None
|
||||
if isinstance(self.model, ParserNeuralNet):
|
||||
nn_model = self.model
|
||||
else:
|
||||
ap_model = self.model
|
||||
for i in range(beam.size):
|
||||
py_eg.reset()
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.c.is_final():
|
||||
self.model.set_featuresC(eg, stcls.c)
|
||||
self.model.set_scoresC(beam.scores[i], eg.features, eg.nr_feat)
|
||||
nr_feat = nn_model._set_featuresC(features, stcls.c)
|
||||
self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
py_eg.reset()
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.c.is_final():
|
||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
|
||||
|
|
Loading…
Reference in New Issue
Block a user