Update beam parser

This commit is contained in:
Matthew Honnibal 2017-08-16 18:25:49 -05:00
parent 4b1e7bd6d8
commit 0209a06b4e
3 changed files with 53 additions and 49 deletions

View File

@ -6,6 +6,7 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
from thinc.extra.search cimport Beam from thinc.extra.search cimport Beam
from thinc.extra.search import MaxViolation from thinc.extra.search import MaxViolation
from thinc.typedefs cimport hash_t, class_t from thinc.typedefs cimport hash_t, class_t
from thinc.extra.search cimport MaxViolation
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
from .stateclass cimport StateClass from .stateclass cimport StateClass
@ -45,6 +46,7 @@ cdef class ParserBeam(object):
cdef public object states cdef public object states
cdef public object golds cdef public object golds
cdef public object beams cdef public object beams
cdef public object dones
def __init__(self, TransitionSystem moves, states, golds, def __init__(self, TransitionSystem moves, states, golds,
int width=4, float density=0.001): int width=4, float density=0.001):
@ -61,6 +63,7 @@ cdef class ParserBeam(object):
st = <StateClass>beam.at(i) st = <StateClass>beam.at(i)
st.c.offset = state.c.offset st.c.offset = state.c.offset
self.beams.append(beam) self.beams.append(beam)
self.dones = [False] * len(self.beams)
def __dealloc__(self): def __dealloc__(self):
if self.beams is not None: if self.beams is not None:
@ -70,7 +73,7 @@ cdef class ParserBeam(object):
@property @property
def is_done(self): def is_done(self):
return all(b.is_done for b in self.beams) return all(b.is_done or self.dones[i] for i, b in enumerate(self.beams))
def __getitem__(self, i): def __getitem__(self, i):
return self.beams[i] return self.beams[i]
@ -81,19 +84,24 @@ cdef class ParserBeam(object):
def advance(self, scores, follow_gold=False): def advance(self, scores, follow_gold=False):
cdef Beam beam cdef Beam beam
for i, beam in enumerate(self.beams): for i, beam in enumerate(self.beams):
if beam.is_done or not scores[i].size: if beam.is_done or not scores[i].size or self.dones[i]:
continue continue
self._set_scores(beam, scores[i]) self._set_scores(beam, scores[i])
if self.golds is not None: if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold) self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
beam.advance(_transition_state, NULL, <void*>self.moves.c) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
if beam.is_done: if beam.is_done and self.golds is not None:
for j in range(beam.size): for j in range(beam.size):
if is_gold(<StateClass>beam.at(j), self.golds[i], self.moves.strings): state = <StateClass>beam.at(j)
beam._states[j].loss = 0.0 if state.is_final():
elif beam._states[j].loss == 0.0: try:
beam._states[j].loss = 1.0 if self.moves.is_gold_parse(state, self.golds[i]):
beam._states[j].loss = 0.0
elif beam._states[j].loss == 0.0:
beam._states[j].loss = 1.0
except NotImplementedError:
break
def _set_scores(self, Beam beam, float[:, ::1] scores): def _set_scores(self, Beam beam, float[:, ::1] scores):
cdef float* c_scores = &scores[0, 0] cdef float* c_scores = &scores[0, 0]
@ -110,7 +118,6 @@ cdef class ParserBeam(object):
beam.scores[i][j] = 0 beam.scores[i][j] = 0
beam.costs[i][j] = 0 beam.costs[i][j] = 0
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
for i in range(beam.size): for i in range(beam.size):
state = <StateClass>beam.at(i) state = <StateClass>beam.at(i)
@ -122,21 +129,6 @@ cdef class ParserBeam(object):
beam.is_valid[i][j] = 0 beam.is_valid[i][j] = 0
def is_gold(StateClass state, GoldParse gold, strings):
predicted = set()
truth = set()
for i in range(gold.length):
if gold.cand_to_gold[i] is None:
continue
if state.safe_get(i).dep:
predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted
def get_token_ids(states, int n_tokens): def get_token_ids(states, int n_tokens):
cdef StateClass state cdef StateClass state
cdef np.ndarray ids = numpy.zeros((len(states), n_tokens), cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
@ -156,16 +148,19 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
state2vec, vec2scores, drop=0., sgd=None, state2vec, vec2scores, drop=0., sgd=None,
losses=None, int width=4, float density=0.001): losses=None, int width=4, float density=0.001):
global nr_update global nr_update
cdef MaxViolation violn
nr_update += 1 nr_update += 1
pbeam = ParserBeam(moves, states, golds, pbeam = ParserBeam(moves, states, golds,
width=width, density=density) width=width, density=density)
gbeam = ParserBeam(moves, states, golds, gbeam = ParserBeam(moves, states, golds,
width=width, density=0.0) width=width, density=density)
cdef StateClass state cdef StateClass state
beam_maps = [] beam_maps = []
backprops = [] backprops = []
violns = [MaxViolation() for _ in range(len(states))] violns = [MaxViolation() for _ in range(len(states))]
for t in range(max_steps): for t in range(max_steps):
if pbeam.is_done and gbeam.is_done:
break
# The beam maps let us find the right row in the flattened scores # The beam maps let us find the right row in the flattened scores
# arrays for each state. States are identified by (example id, history). # arrays for each state. States are identified by (example id, history).
# We keep a different beam map for each step (since we'll have a flat # We keep a different beam map for each step (since we'll have a flat
@ -197,12 +192,16 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
# Track the "maximum violation", to use in the update. # Track the "maximum violation", to use in the update.
for i, violn in enumerate(violns): for i, violn in enumerate(violns):
violn.check_crf(pbeam[i], gbeam[i]) violn.check_crf(pbeam[i], gbeam[i])
histories = []
# Only make updates if we have non-gold states losses = []
histories = [((v.p_hist + v.g_hist) if v.p_hist else []) for v in violns] for i, violn in enumerate(violns):
losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns] if violn.cost < 1:
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories.append([])
histories, losses) losses.append([])
else:
histories.append(violn.p_hist + violn.g_hist)
losses.append(violn.p_probs + violn.g_probs)
states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, losses)
return states_d_scores, backprops[:len(states_d_scores)] return states_d_scores, backprops[:len(states_d_scores)]
@ -216,7 +215,9 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)): for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
p_indices.append([]) p_indices.append([])
g_indices.append([]) g_indices.append([])
if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update): if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + numpy.sqrt(nr_update)):
pbeams.dones[eg_id] = True
gbeams.dones[eg_id] = True
continue continue
for i in range(pbeam.size): for i in range(pbeam.size):
state = <StateClass>pbeam.at(i) state = <StateClass>pbeam.at(i)
@ -261,21 +262,21 @@ def get_gradient(nr_class, beam_maps, histories, losses):
nr_step = 0 nr_step = 0
for eg_id, hists in enumerate(histories): for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists): for loss, hist in zip(losses[eg_id], hists):
if abs(loss) >= 0.0001 and not numpy.isnan(loss): if loss != 0.0 and not numpy.isnan(loss):
nr_step = max(nr_step, len(hist)) nr_step = max(nr_step, len(hist))
for i in range(nr_step): for i in range(nr_step):
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f')) grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
assert len(histories) == len(losses) assert len(histories) == len(losses)
for eg_id, hists in enumerate(histories): for eg_id, hists in enumerate(histories):
for loss, hist in zip(losses[eg_id], hists): for loss, hist in zip(losses[eg_id], hists):
if abs(loss) < 0.0001 or numpy.isnan(loss): if abs(loss) == 0.0 or numpy.isnan(loss):
continue continue
key = tuple([eg_id]) key = tuple([eg_id])
for j, clas in enumerate(hist): for j, clas in enumerate(hist):
i = beam_maps[j][key] i = beam_maps[j][key]
# In step j, at state i action clas # In step j, at state i action clas
# resulted in loss # resulted in loss
grads[j][i, clas] += loss / len(histories) grads[j][i, clas] += loss
key = key + tuple([clas]) key = key + tuple([clas])
return grads return grads

View File

@ -34,7 +34,6 @@ from ._parse_features cimport CONTEXT_SIZE
from ._parse_features cimport fill_context from ._parse_features cimport fill_context
from .stateclass cimport StateClass from .stateclass cimport StateClass
from .parser cimport Parser from .parser cimport Parser
from ._beam_utils import is_gold
DEBUG = False DEBUG = False
@ -108,7 +107,7 @@ cdef class BeamParser(Parser):
# The non-monotonic oracle makes it difficult to ensure final costs are # The non-monotonic oracle makes it difficult to ensure final costs are
# correct. Therefore do final correction # correct. Therefore do final correction
for i in range(pred.size): for i in range(pred.size):
if is_gold(<StateClass>pred.at(i), gold_parse, self.moves.strings): if self.moves.is_gold_parse(<StateClass>pred.at(i), gold_parse):
pred._states[i].loss = 0.0 pred._states[i].loss = 0.0
elif pred._states[i].loss == 0.0: elif pred._states[i].loss == 0.0:
pred._states[i].loss = 1.0 pred._states[i].loss = 1.0
@ -214,7 +213,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
if not pred._states[i].is_done or pred._states[i].loss == 0: if not pred._states[i].is_done or pred._states[i].loss == 0:
continue continue
state = <StateClass>pred.at(i) state = <StateClass>pred.at(i)
if is_gold(state, gold_parse, moves.strings) == True: if moves.is_gold_parse(state, gold_parse) == True:
for dep in gold_parse.orig_annot: for dep in gold_parse.orig_annot:
print(dep[1], dep[3], dep[4]) print(dep[1], dep[3], dep[4])
print("Cost", pred._states[i].loss) print("Cost", pred._states[i].loss)
@ -228,7 +227,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
if not gold._states[i].is_done: if not gold._states[i].is_done:
continue continue
state = <StateClass>gold.at(i) state = <StateClass>gold.at(i)
if is_gold(state, gold_parse, moves.strings) == False: if moves.is_gold(state, gold_parse) == False:
print("Truth") print("Truth")
for dep in gold_parse.orig_annot: for dep in gold_parse.orig_annot:
print(dep[1], dep[3], dep[4]) print(dep[1], dep[3], dep[4])

View File

@ -38,6 +38,7 @@ from preshed.maps cimport map_get
from thinc.api import layerize, chain, noop, clone from thinc.api import layerize, chain, noop, clone
from thinc.neural import Model, Affine, ReLu, Maxout from thinc.neural import Model, Affine, ReLu, Maxout
from thinc.neural._classes.batchnorm import BatchNorm as BN
from thinc.neural._classes.selu import SELU from thinc.neural._classes.selu import SELU
from thinc.neural._classes.layernorm import LayerNorm from thinc.neural._classes.layernorm import LayerNorm
from thinc.neural.ops import NumpyOps, CupyOps from thinc.neural.ops import NumpyOps, CupyOps
@ -258,7 +259,7 @@ cdef class Parser:
with Model.use_device('cpu'): with Model.use_device('cpu'):
upper = chain( upper = chain(
clone(Residual(ReLu(hidden_width)), (depth-1)), clone(Maxout(hidden_width), (depth-1)),
zero_init(Affine(nr_class, drop_factor=0.0)) zero_init(Affine(nr_class, drop_factor=0.0))
) )
# TODO: This is an unfortunate hack atm! # TODO: This is an unfortunate hack atm!
@ -321,6 +322,8 @@ cdef class Parser:
beam_width = self.cfg.get('beam_width', 1) beam_width = self.cfg.get('beam_width', 1)
if beam_density is None: if beam_density is None:
beam_density = self.cfg.get('beam_density', 0.001) beam_density = self.cfg.get('beam_density', 0.001)
if BEAM_PARSE:
beam_width = 16
cdef Beam beam cdef Beam beam
if beam_width == 1: if beam_width == 1:
states = self.parse_batch([doc], [doc.tensor]) states = self.parse_batch([doc], [doc.tensor])
@ -349,7 +352,7 @@ cdef class Parser:
Yields (Doc): Documents, in order. Yields (Doc): Documents, in order.
""" """
if BEAM_PARSE: if BEAM_PARSE:
beam_width = 8 beam_width = 16
cdef Doc doc cdef Doc doc
cdef Beam beam cdef Beam beam
for docs in cytoolz.partition_all(batch_size, docs): for docs in cytoolz.partition_all(batch_size, docs):
@ -427,7 +430,7 @@ cdef class Parser:
next_step.push_back(st) next_step.push_back(st)
return states return states
def beam_parse(self, docs, tokvecses, int beam_width=8, float beam_density=0.001): def beam_parse(self, docs, tokvecses, int beam_width=16, float beam_density=0.001):
cdef Beam beam cdef Beam beam
cdef np.ndarray scores cdef np.ndarray scores
cdef Doc doc cdef Doc doc
@ -471,13 +474,13 @@ cdef class Parser:
for k in range(nr_class): for k in range(nr_class):
beam.scores[i][k] = c_scores[j * scores.shape[1] + k] beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
j += 1 j += 1
beam.advance(_transition_state, NULL, <void*>self.moves.c) beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL) beam.check_done(_check_final_state, NULL)
beams.append(beam) beams.append(beam)
return beams return beams
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
if BEAM_PARSE: if BEAM_PARSE and numpy.random.random() >= 0.5:
return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd, return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd,
losses=losses) losses=losses)
if losses is not None and self.name not in losses: if losses is not None and self.name not in losses:
@ -568,7 +571,7 @@ cdef class Parser:
states, tokvecs, golds, states, tokvecs, golds,
state2vec, vec2scores, state2vec, vec2scores,
drop, sgd, losses, drop, sgd, losses,
width=8) width=16)
backprop_lower = [] backprop_lower = []
for i, d_scores in enumerate(states_d_scores): for i, d_scores in enumerate(states_d_scores):
if losses is not None: if losses is not None:
@ -633,9 +636,10 @@ cdef class Parser:
xp = get_array_module(d_tokvecs) xp = get_array_module(d_tokvecs)
for ids, d_vector, bp_vector in backprops: for ids, d_vector, bp_vector in backprops:
d_state_features = bp_vector(d_vector, sgd=sgd) d_state_features = bp_vector(d_vector, sgd=sgd)
mask = (ids >= 0).reshape((ids.shape[0], ids.shape[1], 1)) mask = ids >= 0
self.model[0].ops.scatter_add(d_tokvecs, ids, d_state_features *= mask.reshape(ids.shape + (1,))
d_state_features * mask) self.model[0].ops.scatter_add(d_tokvecs, ids * mask,
d_state_features)
@property @property
def move_names(self): def move_names(self):
@ -651,7 +655,7 @@ cdef class Parser:
lower, stream, drop=dropout) lower, stream, drop=dropout)
return state2vec, upper return state2vec, upper
nr_feature = 13 nr_feature = 8
def get_token_ids(self, states): def get_token_ids(self, states):
cdef StateClass state cdef StateClass state