Fix beam_parser for new API

This commit is contained in:
Matthew Honnibal 2016-07-31 19:03:10 +02:00
parent 2f09b041d1
commit a664aa8180

View File

@ -118,26 +118,28 @@ cdef class BeamParser(Parser):
for grad, hist in histories: for grad, hist in histories:
assert not math.isnan(grad) and not math.isinf(grad) assert not math.isnan(grad) and not math.isinf(grad)
if abs(grad) >= min_grad: if abs(grad) >= min_grad:
self._update_from_history(self.moves, tokens, hist, grad) self.model._update_from_history(self.moves, tokens, hist, grad)
_cleanup(pred) _cleanup(pred)
_cleanup(gold) _cleanup(gold)
return pred.loss return pred.loss
def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
cdef Example py_eg = Example(nr_class=self.moves.n_moves, nr_atom=CONTEXT_SIZE, cdef Pool mem = Pool()
nr_feat=self.model.nr_feat, widths=self.model.widths) features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
cdef ExampleC* eg = py_eg.c cdef ParserNeuralNet nn_model = None
cdef ParserPerceptron ap_model = None
if isinstance(self.model, ParserNeuralNet):
nn_model = self.model
else:
ap_model = self.model
for i in range(beam.size): for i in range(beam.size):
py_eg.reset()
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)
if not stcls.c.is_final(): if not stcls.c.is_final():
self.model.set_featuresC(eg, stcls.c) nr_feat = nn_model._set_featuresC(features, stcls.c)
self.model.set_scoresC(beam.scores[i], eg.features, eg.nr_feat) self.model.set_scoresC(beam.scores[i], features, nr_feat)
self.moves.set_valid(beam.is_valid[i], stcls.c) self.moves.set_valid(beam.is_valid[i], stcls.c)
if gold is not None: if gold is not None:
for i in range(beam.size): for i in range(beam.size):
py_eg.reset()
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)
if not stcls.c.is_final(): if not stcls.c.is_final():
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)