mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Initial scratching on beam early update
This commit is contained in:
parent
2235e3520c
commit
c04ae74268
|
@ -136,7 +136,7 @@ cdef class BeamBatch(object):
|
|||
beam.is_valid[i][j] = 0
|
||||
|
||||
|
||||
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
|
||||
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0, early_update=False):
|
||||
cdef MaxViolation violn
|
||||
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
||||
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
||||
|
@ -175,6 +175,17 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
|||
# Track the "maximum violation", to use in the update.
|
||||
for i, violn in enumerate(violns):
|
||||
if not dones[i]:
|
||||
if (
|
||||
early_update
|
||||
and _lowest_score_has_cost(pbeam[i])
|
||||
and pbeam[i].min_score > gbeam[i].score
|
||||
):
|
||||
pbeam[i].is_done = True
|
||||
gbeam[i].is_done = True
|
||||
# Make sure this state is the 'max violation',
|
||||
# by setting the previous violn delta
|
||||
# to -1
|
||||
violn.delta = -1
|
||||
violn.check_crf(pbeam[i], gbeam[i])
|
||||
if pbeam[i].is_done and gbeam[i].is_done:
|
||||
dones[i] = True
|
||||
|
@ -196,6 +207,13 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
|||
return loss
|
||||
|
||||
|
||||
def _lowest_score_has_cost(beam: Beam) -> bool:
|
||||
"""Check whether the lowest-scoring candidate
|
||||
in a parse is marked as non-gold, i.e. it has a cost
|
||||
> 0."""
|
||||
return beam._states[beam.size-1].loss > 0
|
||||
|
||||
|
||||
def collect_states(beams, docs):
|
||||
cdef StateClass state
|
||||
cdef Beam beam
|
||||
|
|
Loading…
Reference in New Issue
Block a user