mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Initial scratching on beam early update
This commit is contained in:
parent
2235e3520c
commit
c04ae74268
|
@ -136,7 +136,7 @@ cdef class BeamBatch(object):
|
||||||
beam.is_valid[i][j] = 0
|
beam.is_valid[i][j] = 0
|
||||||
|
|
||||||
|
|
||||||
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0):
|
def update_beam(TransitionSystem moves, states, golds, model, int width, beam_density=0.0, early_update=False):
|
||||||
cdef MaxViolation violn
|
cdef MaxViolation violn
|
||||||
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
||||||
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
||||||
|
@ -175,6 +175,17 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
||||||
# Track the "maximum violation", to use in the update.
|
# Track the "maximum violation", to use in the update.
|
||||||
for i, violn in enumerate(violns):
|
for i, violn in enumerate(violns):
|
||||||
if not dones[i]:
|
if not dones[i]:
|
||||||
|
if (
|
||||||
|
early_update
|
||||||
|
and _lowest_score_has_cost(pbeam[i])
|
||||||
|
and pbeam[i].min_score > gbeam[i].score
|
||||||
|
):
|
||||||
|
pbeam[i].is_done = True
|
||||||
|
gbeam[i].is_done = True
|
||||||
|
# Make sure this state is the 'max violation',
|
||||||
|
# by setting the previous violn delta
|
||||||
|
# to -1
|
||||||
|
violn.delta = -1
|
||||||
violn.check_crf(pbeam[i], gbeam[i])
|
violn.check_crf(pbeam[i], gbeam[i])
|
||||||
if pbeam[i].is_done and gbeam[i].is_done:
|
if pbeam[i].is_done and gbeam[i].is_done:
|
||||||
dones[i] = True
|
dones[i] = True
|
||||||
|
@ -196,6 +207,13 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _lowest_score_has_cost(beam: Beam) -> bool:
    """Check whether the lowest-scoring candidate in a beam is non-gold.

    The candidate inspected is the last occupied slot of the beam,
    ``beam._states[beam.size - 1]``. It counts as non-gold when its
    ``loss`` is greater than 0.

    beam (Beam): The beam to inspect.
    RETURNS (bool): True if the last candidate has a cost > 0. False if that
        candidate has no cost, or if the beam holds no candidates at all.
    """
    # NOTE(review): presumably the beam keeps its candidates ordered from
    # highest to lowest score, which is what would make the last slot the
    # lowest-scoring one -- confirm against the Beam implementation.
    if beam.size <= 0:
        # An empty beam has no last candidate. If `_states` is a raw C array
        # (TODO confirm), index -1 would read outside of it rather than wrap
        # around, so answer "no cost" explicitly instead of indexing.
        return False
    return beam._states[beam.size - 1].loss > 0
|
||||||
|
|
||||||
|
|
||||||
def collect_states(beams, docs):
|
def collect_states(beams, docs):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef Beam beam
|
cdef Beam beam
|
||||||
|
|
Loading…
Reference in New Issue
Block a user