Fix beam parser, improve efficiency of non-beam

Matthew Honnibal 2017-08-13 12:37:26 +02:00
parent 4363b4aa4a
commit 6a42cc16ff
3 changed files with 39 additions and 52 deletions

View File

@@ -1,4 +1,5 @@
 # cython: infer_types=True
+# cython: profile=True
 cimport numpy as np
 import numpy
 from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
@@ -155,8 +156,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
     backprops = []
     violns = [MaxViolation() for _ in range(len(states))]
     for t in range(max_steps):
-        if pbeam.is_done and gbeam.is_done:
-            break
         beam_maps.append({})
         states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1], nr_update)
         if not states:
@@ -174,16 +173,6 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
     for i, violn in enumerate(violns):
         violn.check_crf(pbeam[i], gbeam[i])
-    # The non-monotonic oracle makes it difficult to ensure final costs are
-    # correct. Therefore do final correction
-    cdef Beam pred
-    for i, (pred, gold_parse) in enumerate(zip(pbeam, golds)):
-        for j in range(pred.size):
-            if is_gold(<StateClass>pred.at(j), gold_parse, moves.strings):
-                pred._states[j].loss = 0.0
-            elif pred._states[j].loss == 0.0:
-                pred._states[j].loss = 1.0
-        violn.check_crf(pred, gbeam[i])
     histories = [(v.p_hist + v.g_hist) for v in violns]
     losses = [(v.p_probs + v.g_probs) for v in violns]
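Note: each MaxViolation pairs the predicted beam's history and scores with the gold beam's, so the concatenated histories and losses lists stay aligned position by position. A minimal sketch of that pairing, with a stand-in object instead of thinc's MaxViolation (names and values illustrative only):

    # Stand-in for thinc's MaxViolation: p_hist/g_hist are lists of action
    # sequences, p_probs/g_probs the matching per-sequence weights.
    class FakeViolation(object):
        def __init__(self, p_hist, g_hist, p_probs, g_probs):
            self.p_hist, self.g_hist = p_hist, g_hist
            self.p_probs, self.g_probs = p_probs, g_probs

    violns = [FakeViolation([[0, 2]], [[0, 1]], [0.7], [-0.7])]
    # Concatenation keeps predicted and gold entries aligned by position:
    histories = [(v.p_hist + v.g_hist) for v in violns]
    losses = [(v.p_probs + v.g_probs) for v in violns]
    assert len(histories[0]) == len(losses[0]) == 2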
@@ -199,20 +188,18 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
     g_indices = []
     cdef Beam pbeam, gbeam
     for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
-        if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update):
-            continue
         p_indices.append([])
-        for j in range(pbeam.size):
-            state = <StateClass>pbeam.at(j)
+        for i in range(pbeam.size):
+            state = <StateClass>pbeam.at(i)
             if not state.is_final():
-                key = tuple([eg_id] + pbeam.histories[j])
+                key = tuple([eg_id] + pbeam.histories[i])
                 seen[key] = len(states)
                 p_indices[-1].append(len(states))
-                states.append(<StateClass>pbeam.at(j))
+                states.append(<StateClass>pbeam.at(i))
         beam_map.update(seen)
         g_indices.append([])
         for i in range(gbeam.size):
-            state = <StateClass>gbeam.at(j)
+            state = <StateClass>gbeam.at(i)
             if not state.is_final():
                 key = tuple([eg_id] + gbeam.histories[i])
                 if key in seen:
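Note: get_states deduplicates unfinished states by keying them on (example id, action history), so a state reached by both the predicted and the gold beam is scored once; the loop-variable fix above matters because the old code read `gbeam.at(j)` with a stale index from the previous loop, pairing keys with the wrong states. A rough pure-Python sketch of the bookkeeping (hypothetical helper, no Cython types):

    # Hypothetical sketch: dedupe beam states by (eg_id, history) key.
    def collect_states(p_histories, g_histories, eg_id=0):
        seen, states, p_indices, g_indices = {}, [], [], []
        for hist in p_histories:
            key = tuple([eg_id] + hist)
            seen[key] = len(states)
            p_indices.append(len(states))
            states.append(hist)
        for hist in g_histories:
            key = tuple([eg_id] + hist)
            if key in seen:
                g_indices.append(seen[key])  # reuse the predicted state's slot
            else:
                g_indices.append(len(states))
                states.append(hist)
        return states, p_indices, g_indices

    states, p_idx, g_idx = collect_states([[0, 1], [0, 2]], [[0, 1]])
    assert g_idx == [0]  # the shared history resolves to the same slot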
@@ -243,17 +230,17 @@ def get_gradient(nr_class, beam_maps, histories, losses):
     nr_step = len(beam_maps)
     grads = []
     for beam_map in beam_maps:
-        if beam_map:
-            grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
-        else:
-            grads.append(None)
+        grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
+    assert len(histories) == len(losses)
     for eg_id, hists in enumerate(histories):
         for loss, hist in zip(losses[eg_id], hists):
             key = tuple([eg_id])
             for j, clas in enumerate(hist):
-                if grads[j] is None:
-                    continue
-                i = beam_maps[j][key]
+                try:
+                    i = beam_maps[j][key]
+                except:
+                    print(sorted(beam_maps[j].items()))
+                    raise
                 # In step j, at state i action clas
                 # resulted in loss
                 grads[j][i, clas] += loss
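Note: the indexing convention here is that beam_maps[j] maps a history prefix to a row of step j's gradient matrix, and a history's loss is added at every (state, action) pair along it. A toy version of the loop, assuming the key is extended by one action per step (the extension line falls outside this hunk, so treat it as an assumption):

    import numpy

    def toy_gradient(nr_class, beam_maps, histories, losses):
        # One gradient matrix per step; row count comes from the beam map.
        grads = [numpy.zeros((max(m.values()) + 1, nr_class), dtype='f')
                 for m in beam_maps]
        for eg_id, hists in enumerate(histories):
            for loss, hist in zip(losses[eg_id], hists):
                key = (eg_id,)
                for j, clas in enumerate(hist):
                    i = beam_maps[j][key]        # row for this state at step j
                    grads[j][i, clas] += loss    # action `clas` gets blame `loss`
                    key = key + (clas,)          # assumed prefix extension
        return grads

    # One example, one history [0, 1], loss 0.5:
    g = toy_gradient(2, [{(0,): 0}, {(0, 0): 0}], [[[0, 1]]], [[0.5]])
    assert g[1][0, 1] == 0.5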

View File

@@ -34,6 +34,7 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 from .parser cimport Parser
+from ._beam_utils import is_gold


 DEBUG = False
@@ -237,16 +238,3 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
             raise Exception("Gold parse is not gold-standard")


-def is_gold(StateClass state, GoldParse gold, StringStore strings):
-    predicted = set()
-    truth = set()
-    for i in range(gold.length):
-        if gold.cand_to_gold[i] is None:
-            continue
-        if state.safe_get(i).dep:
-            predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
-        else:
-            predicted.add((i, state.H(i), 'ROOT'))
-        id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
-        truth.add((id_, head, dep))
-    return truth == predicted
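Note: is_gold (now imported from _beam_utils above) is an exact-match test between the state's arcs and the gold annotation. A plain-Python analogue of the comparison, with the state and gold objects reduced to bare (child, head, label) arc lists (illustrative, not the Cython API):

    # A parse is "gold" iff its (child, head, label) arcs equal the gold
    # arcs as sets; order is irrelevant, membership is not.
    def is_gold_arcs(predicted_arcs, gold_arcs):
        return set(predicted_arcs) == set(gold_arcs)

    pred = [(0, 1, 'nsubj'), (1, 1, 'ROOT')]
    gold = [(1, 1, 'ROOT'), (0, 1, 'nsubj')]
    assert is_gold_arcs(pred, gold)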

View File

@@ -66,7 +66,7 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
 from . import _beam_utils

 USE_FINE_TUNE = True
-BEAM_PARSE = False
+BEAM_PARSE = True


 def get_templates(*args, **kwargs):
     return []
@@ -348,6 +348,8 @@ cdef class Parser:
            The number of threads with which to work on the buffer in parallel.
        Yields (Doc): Documents, in order.
        """
+        if BEAM_PARSE:
+            beam_width = 8
         cdef Doc doc
         cdef Beam beam
         for docs in cytoolz.partition_all(batch_size, docs):
@@ -439,6 +441,8 @@ cdef class Parser:
                                              cuda_stream, 0.0)
         beams = []
         cdef int offset = 0
+        cdef int j = 0
+        cdef int k
         for doc in docs:
             beam = Beam(nr_class, beam_width, min_density=beam_density)
             beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
@@ -451,16 +455,22 @@ cdef class Parser:
                 states = []
                 for i in range(beam.size):
                     stcls = <StateClass>beam.at(i)
-                    states.append(stcls)
+                    # This way we avoid having to score finalized states
+                    # We do have to take care to keep indexes aligned, though
+                    if not stcls.is_final():
+                        states.append(stcls)
                 token_ids = self.get_token_ids(states)
                 vectors = state2vec(token_ids)
                 scores = vec2scores(vectors)
+                j = 0
+                c_scores = <float*>scores.data
                 for i in range(beam.size):
                     stcls = <StateClass>beam.at(i)
                     if not stcls.is_final():
                         self.moves.set_valid(beam.is_valid[i], stcls.c)
-                        for j in range(nr_class):
-                            beam.scores[i][j] = scores[i, j]
+                        for k in range(nr_class):
+                            beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
+                        j += 1
                 beam.advance(_transition_state, _hash_state, <void*>self.moves.c)
                 beam.check_done(_check_final_state, NULL)
                 beams.append(beam)
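Note: the efficiency change scores only unfinished states, so the score matrix has fewer rows than the beam has slots; j walks the compacted rows while i walks the beam, and j only advances when a row is consumed. A NumPy sketch of the same two-counter alignment (the scaffolding around it is hypothetical):

    import numpy

    def broadcast_scores(is_final, scores, nr_class):
        # `scores` has one row per *unfinished* state, in beam order.
        beam_scores = numpy.zeros((len(is_final), nr_class), dtype='f')
        j = 0  # row index into the compacted score matrix
        for i, final in enumerate(is_final):
            if not final:
                beam_scores[i] = scores[j]
                j += 1  # only advance past rows we actually consumed
        return beam_scores

    scores = numpy.asarray([[1., 2.], [3., 4.]], dtype='f')
    out = broadcast_scores([False, True, False], scores, 2)
    assert (out[2] == scores[1]).all()  # slot 2 gets the second scored row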
@@ -540,6 +550,7 @@ cdef class Parser:
             losses[self.name] = 0.
         docs, tokvecs = docs_tokvecs
         lengths = [len(d) for d in docs]
+        assert min(lengths) >= 1
         tokvecs = self.model[0].ops.flatten(tokvecs)
         if USE_FINE_TUNE:
             my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
@@ -554,9 +565,14 @@ cdef class Parser:
         states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, max_moves,
                                                              states, tokvecs, golds,
                                                              state2vec, vec2scores,
-                                                             drop, sgd, losses)
+                                                             drop, sgd, losses,
+                                                             width=8)
         backprop_lower = []
         for i, d_scores in enumerate(states_d_scores):
+            if d_scores is None:
+                continue
+            if losses is not None:
+                losses[self.name] += (d_scores**2).sum()
             ids, bp_vectors, bp_scores = backprops[i]
             d_vector = bp_scores(d_scores, sgd=sgd)
             if isinstance(self.model[0].ops, CupyOps) \
@@ -617,14 +633,10 @@ cdef class Parser:
         xp = get_array_module(d_tokvecs)
         for ids, d_vector, bp_vector in backprops:
             d_state_features = bp_vector(d_vector, sgd=sgd)
-            active_feats = ids * (ids >= 0)
-            active_feats = active_feats.reshape((ids.shape[0], ids.shape[1], 1))
-            if hasattr(xp, 'scatter_add'):
-                xp.scatter_add(d_tokvecs,
-                               ids, d_state_features * active_feats)
-            else:
-                xp.add.at(d_tokvecs,
-                          ids, d_state_features * active_feats)
+            mask = ids >= 0
+            indices = xp.nonzero(mask)
+            self.model[0].ops.scatter_add(d_tokvecs, ids[indices],
+                                          d_state_features[indices])

     @property
     def move_names(self):
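Note: the rewritten backprop treats negative ids as padding, masks them out, and scatter-adds only the valid rows into the token-vector gradient. A NumPy equivalent of the masked scatter-add, using numpy.add.at in place of the thinc ops method:

    import numpy

    d_tokvecs = numpy.zeros((4, 3), dtype='f')  # gradient w.r.t. token vectors
    ids = numpy.asarray([[0, 2], [-1, 3]])      # -1 marks a padded feature
    d_state_features = numpy.ones((2, 2, 3), dtype='f')

    mask = ids >= 0
    indices = numpy.nonzero(mask)               # positions of real ids only
    # Accumulate per-feature gradients; duplicate ids sum correctly.
    numpy.add.at(d_tokvecs, ids[indices], d_state_features[indices])

    assert (d_tokvecs[1] == 0).all()            # never-referenced row untouched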