mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Minibatch beam candidates, for faster decoding
This commit is contained in:
parent
eb145dc1b8
commit
0fb188c76c
|
@ -34,6 +34,7 @@ from thinc.structs cimport FeatureC, ExampleC
|
|||
from thinc.extra.search cimport Beam
|
||||
from thinc.extra.search cimport MaxViolation
|
||||
from thinc.extra.eg cimport Example
|
||||
from thinc.extra.mb cimport Minibatch
|
||||
|
||||
from ..structs cimport TokenC
|
||||
|
||||
|
@ -136,12 +137,19 @@ cdef class BeamParser(Parser):
|
|||
nn_model = self.model
|
||||
else:
|
||||
ap_model = self.model
|
||||
raise NotImplementedError
|
||||
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
if not stcls.c.is_final():
|
||||
if stcls.c.is_final():
|
||||
nr_feat = 0
|
||||
else:
|
||||
nr_feat = nn_model._set_featuresC(features, stcls.c)
|
||||
self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
||||
self.model(mb)
|
||||
for i in range(beam.size):
|
||||
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
||||
if gold is not None:
|
||||
for i in range(beam.size):
|
||||
stcls = <StateClass>beam.at(i)
|
||||
|
|
Loading…
Reference in New Issue
Block a user