Minibatch beam candidates, for faster decoding

This commit is contained in:
Matthew Honnibal 2016-08-08 01:38:50 +02:00
parent eb145dc1b8
commit 0fb188c76c

View File

@@ -34,6 +34,7 @@ from thinc.structs cimport FeatureC, ExampleC
from thinc.extra.search cimport Beam from thinc.extra.search cimport Beam
from thinc.extra.search cimport MaxViolation from thinc.extra.search cimport MaxViolation
from thinc.extra.eg cimport Example from thinc.extra.eg cimport Example
from thinc.extra.mb cimport Minibatch
from ..structs cimport TokenC from ..structs cimport TokenC
@@ -136,12 +137,19 @@ cdef class BeamParser(Parser):
nn_model = self.model nn_model = self.model
else: else:
ap_model = self.model ap_model = self.model
raise NotImplementedError
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)
if not stcls.c.is_final(): if stcls.c.is_final():
nr_feat = 0
else:
nr_feat = nn_model._set_featuresC(features, stcls.c) nr_feat = nn_model._set_featuresC(features, stcls.c)
self.model.set_scoresC(beam.scores[i], features, nr_feat)
self.moves.set_valid(beam.is_valid[i], stcls.c) self.moves.set_valid(beam.is_valid[i], stcls.c)
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
self.model(mb)
for i in range(beam.size):
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
if gold is not None: if gold is not None:
for i in range(beam.size): for i in range(beam.size):
stcls = <StateClass>beam.at(i) stcls = <StateClass>beam.at(i)