Minibatch beam candidates, for faster decoding

This commit is contained in:
Matthew Honnibal 2016-08-08 01:38:50 +02:00
parent eb145dc1b8
commit 0fb188c76c

View File

@ -34,6 +34,7 @@ from thinc.structs cimport FeatureC, ExampleC
from thinc.extra.search cimport Beam
from thinc.extra.search cimport MaxViolation
from thinc.extra.eg cimport Example
from thinc.extra.mb cimport Minibatch
from ..structs cimport TokenC
@ -136,12 +137,19 @@ cdef class BeamParser(Parser):
nn_model = self.model
else:
ap_model = self.model
raise NotImplementedError
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
for i in range(beam.size):
stcls = <StateClass>beam.at(i)
if not stcls.c.is_final():
if stcls.c.is_final():
nr_feat = 0
else:
nr_feat = nn_model._set_featuresC(features, stcls.c)
self.model.set_scoresC(beam.scores[i], features, nr_feat)
self.moves.set_valid(beam.is_valid[i], stcls.c)
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
self.model(mb)
for i in range(beam.size):
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
if gold is not None:
for i in range(beam.size):
stcls = <StateClass>beam.at(i)