mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-06 12:53:19 +03:00
Minibatch beam candidates, for faster decoding
This commit is contained in:
parent
eb145dc1b8
commit
0fb188c76c
|
@ -34,6 +34,7 @@ from thinc.structs cimport FeatureC, ExampleC
|
||||||
from thinc.extra.search cimport Beam
|
from thinc.extra.search cimport Beam
|
||||||
from thinc.extra.search cimport MaxViolation
|
from thinc.extra.search cimport MaxViolation
|
||||||
from thinc.extra.eg cimport Example
|
from thinc.extra.eg cimport Example
|
||||||
|
from thinc.extra.mb cimport Minibatch
|
||||||
|
|
||||||
from ..structs cimport TokenC
|
from ..structs cimport TokenC
|
||||||
|
|
||||||
|
@ -136,12 +137,19 @@ cdef class BeamParser(Parser):
|
||||||
nn_model = self.model
|
nn_model = self.model
|
||||||
else:
|
else:
|
||||||
ap_model = self.model
|
ap_model = self.model
|
||||||
|
raise NotImplementedError
|
||||||
|
cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
stcls = <StateClass>beam.at(i)
|
||||||
if not stcls.c.is_final():
|
if stcls.c.is_final():
|
||||||
|
nr_feat = 0
|
||||||
|
else:
|
||||||
nr_feat = nn_model._set_featuresC(features, stcls.c)
|
nr_feat = nn_model._set_featuresC(features, stcls.c)
|
||||||
self.model.set_scoresC(beam.scores[i], features, nr_feat)
|
|
||||||
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
self.moves.set_valid(beam.is_valid[i], stcls.c)
|
||||||
|
mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
|
||||||
|
self.model(mb)
|
||||||
|
for i in range(beam.size):
|
||||||
|
memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
|
||||||
if gold is not None:
|
if gold is not None:
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
stcls = <StateClass>beam.at(i)
|
stcls = <StateClass>beam.at(i)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user