Tweaks to beam parser

This commit is contained in:
Matthew Honnibal 2017-08-15 03:15:28 -05:00
parent 500e92553d
commit 23537a011d

View File

@ -216,12 +216,13 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
p_indices.append([])
g_indices.append([])
if pbeam.loss > 0 and pbeam.min_score > gbeam.score:
if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update):
continue
for i in range(pbeam.size):
state = <StateClass>pbeam.at(i)
if not state.is_final():
key = tuple([eg_id] + pbeam.histories[i])
assert key not in seen, (key, seen)
seen[key] = len(states)
p_indices[-1].append(len(states))
states.append(state)
@ -257,12 +258,18 @@ def get_gradient(nr_class, beam_maps, histories, losses):
"""
nr_step = len(beam_maps)
grads = []
for beam_map in beam_maps:
if beam_map:
grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
nr_step = 0
for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists):
if abs(loss) >= 0.0001 and not numpy.isnan(loss):
nr_step = max(nr_step, len(hist))
for i in range(nr_step):
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
assert len(histories) == len(losses)
for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists):
if abs(loss) < 0.0001 or numpy.isnan(loss):
continue
key = tuple([eg_id])
for j, clas in enumerate(hist):
i = beam_maps[j][key]