mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix error in beam gradient calculation
This commit is contained in:
parent
a6ae1ee6f7
commit
a61fd60681
|
@ -133,7 +133,7 @@ cdef class ParserBeam(object):
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
min_cost = 0
|
min_cost = 0
|
||||||
for j in range(beam.nr_class):
|
for j in range(beam.nr_class):
|
||||||
if beam.costs[i][j] < min_cost:
|
if beam.is_valid[i][j] and beam.costs[i][j] < min_cost:
|
||||||
min_cost = beam.costs[i][j]
|
min_cost = beam.costs[i][j]
|
||||||
for j in range(beam.nr_class):
|
for j in range(beam.nr_class):
|
||||||
if beam.costs[i][j] > min_cost:
|
if beam.costs[i][j] > min_cost:
|
||||||
|
@ -160,7 +160,8 @@ nr_update = 0
|
||||||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
states, golds,
|
states, golds,
|
||||||
state2vec, vec2scores,
|
state2vec, vec2scores,
|
||||||
int width, losses=None, drop=0.):
|
int width, losses=None, drop=0.,
|
||||||
|
early_update=True):
|
||||||
global nr_update
|
global nr_update
|
||||||
cdef MaxViolation violn
|
cdef MaxViolation violn
|
||||||
nr_update += 1
|
nr_update += 1
|
||||||
|
@ -201,13 +202,16 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
for indices in p_indices]
|
for indices in p_indices]
|
||||||
g_scores = [numpy.ascontiguousarray(scores[indices], dtype='f')
|
g_scores = [numpy.ascontiguousarray(scores[indices], dtype='f')
|
||||||
for indices in g_indices]
|
for indices in g_indices]
|
||||||
# Now advance the states in the beams. The gold beam is contrained to
|
# Now advance the states in the beams. The gold beam is constrained to
|
||||||
# follow only gold analyses.
|
# follow only gold analyses.
|
||||||
pbeam.advance(p_scores)
|
pbeam.advance(p_scores)
|
||||||
gbeam.advance(g_scores, follow_gold=True)
|
gbeam.advance(g_scores, follow_gold=True)
|
||||||
# Track the "maximum violation", to use in the update.
|
# Track the "maximum violation", to use in the update.
|
||||||
for i, violn in enumerate(violns):
|
for i, violn in enumerate(violns):
|
||||||
violn.check_crf(pbeam[i], gbeam[i])
|
violn.check_crf(pbeam[i], gbeam[i])
|
||||||
|
if pbeam[i].loss > 0 and pbeam[i].min_score > (gbeam[i].score * 1.5):
|
||||||
|
pbeam.dones[i] = True
|
||||||
|
gbeam.dones[i] = True
|
||||||
histories = []
|
histories = []
|
||||||
losses = []
|
losses = []
|
||||||
for violn in violns:
|
for violn in violns:
|
||||||
|
@ -271,14 +275,15 @@ def get_gradient(nr_class, beam_maps, histories, losses):
|
||||||
Each batch has multiple beams
|
Each batch has multiple beams
|
||||||
So history is list of lists of lists of ints
|
So history is list of lists of lists of ints
|
||||||
"""
|
"""
|
||||||
nr_step = len(beam_maps)
|
|
||||||
grads = []
|
grads = []
|
||||||
nr_step = 0
|
nr_steps = []
|
||||||
for eg_id, hists in enumerate(histories):
|
for eg_id, hists in enumerate(histories):
|
||||||
|
nr_step = 0
|
||||||
for loss, hist in zip(losses[eg_id], hists):
|
for loss, hist in zip(losses[eg_id], hists):
|
||||||
if loss != 0.0 and not numpy.isnan(loss):
|
if loss != 0.0 and not numpy.isnan(loss):
|
||||||
nr_step = max(nr_step, len(hist))
|
nr_step = max(nr_step, len(hist))
|
||||||
for i in range(nr_step):
|
nr_steps.append(nr_step)
|
||||||
|
for i in range(max(nr_steps)):
|
||||||
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
|
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
|
||||||
dtype='f'))
|
dtype='f'))
|
||||||
if len(histories) != len(losses):
|
if len(histories) != len(losses):
|
||||||
|
@ -289,8 +294,11 @@ def get_gradient(nr_class, beam_maps, histories, losses):
|
||||||
continue
|
continue
|
||||||
key = tuple([eg_id])
|
key = tuple([eg_id])
|
||||||
# Adjust loss for length
|
# Adjust loss for length
|
||||||
|
# We need to do this because each state in a short path is scored
|
||||||
|
# multiple times, as we add in the average cost when we run out
|
||||||
|
# of actions.
|
||||||
avg_loss = loss / len(hist)
|
avg_loss = loss / len(hist)
|
||||||
loss += avg_loss * (nr_step - len(hist))
|
loss += avg_loss * (nr_steps[eg_id] - len(hist))
|
||||||
for j, clas in enumerate(hist):
|
for j, clas in enumerate(hist):
|
||||||
i = beam_maps[j][key]
|
i = beam_maps[j][key]
|
||||||
# In step j, at state i action clas
|
# In step j, at state i action clas
|
||||||
|
|
Loading…
Reference in New Issue
Block a user