Initial scratching on beam early update

This commit is contained in:
Matthew Honnibal 2022-07-15 12:12:47 +02:00
parent 2235e3520c
commit c04ae74268

View File

@ -136,7 +136,7 @@ cdef class BeamBatch(object):
beam.is_valid[i][j] = 0
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0, early_update=False):
cdef MaxViolation violn
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
@ -175,6 +175,17 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
# Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns):
if not dones[i]:
if (
early_update
and _lowest_score_has_cost(pbeam[i])
and pbeam[i].min_score > gbeam[i].score
):
pbeam[i].is_done = True
gbeam[i].is_done = True
# Make sure this state is the 'max violation',
# by setting the previous violn delta
# to -1
violn.delta = -1
violn.check_crf(pbeam[i], gbeam[i])
if pbeam[i].is_done and gbeam[i].is_done:
dones[i] = True
@ -196,6 +207,13 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
return loss
def _lowest_score_has_cost(beam: Beam) -> bool:
"""Check whether the lowest-scoring candidate
in a parse is marked as non-gold, i.e. it has a cost
> 0."""
return beam._states[beam.size-1].loss > 0
def collect_states(beams, docs):
cdef StateClass state
cdef Beam beam