diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx index fa7df2056..202ffb77e 100644 --- a/spacy/pipeline/_parser_internals/_beam_utils.pyx +++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx @@ -136,7 +136,7 @@ cdef class BeamBatch(object): beam.is_valid[i][j] = 0 -def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0): +def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0, early_update=False): cdef MaxViolation violn pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density) gbeam = BeamBatch(moves, states, golds, width=width, density=0.0) @@ -175,6 +175,17 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de # Track the "maximum violation", to use in the update. for i, violn in enumerate(violns): if not dones[i]: + if ( + early_update + and _lowest_score_has_cost(pbeam[i]) + and pbeam[i].min_score > gbeam[i].score + ): + pbeam[i].is_done = True + gbeam[i].is_done = True + # Make sure this state is the 'max violation', + # by setting the previous violn delta + # to -1 + violn.delta = -1 violn.check_crf(pbeam[i], gbeam[i]) if pbeam[i].is_done and gbeam[i].is_done: dones[i] = True @@ -196,6 +207,13 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de return loss +def _lowest_score_has_cost(beam: Beam) -> bool: + """Check whether the lowest-scoring candidate + in a parse is marked as non-gold, i.e. it has a cost + > 0.""" + return beam._states[beam.size-1].loss > 0 + + def collect_states(beams, docs): cdef StateClass state cdef Beam beam