From 6a42cc16ff673c738e29aa515c1623dde4cf9566 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Sun, 13 Aug 2017 12:37:26 +0200
Subject: [PATCH] Fix beam parser, improve efficiency of non-beam

---
 spacy/syntax/_beam_utils.pyx | 39 ++++++++++++------------------------
 spacy/syntax/beam_parser.pyx | 14 +------------
 spacy/syntax/nn_parser.pyx   | 38 +++++++++++++++++++++++------------
 3 files changed, 39 insertions(+), 52 deletions(-)

diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx
index 0a513531d..6df8d472f 100644
--- a/spacy/syntax/_beam_utils.pyx
+++ b/spacy/syntax/_beam_utils.pyx
@@ -1,4 +1,5 @@
 # cython: infer_types=True
+# cython: profile=True
 cimport numpy as np
 import numpy
 from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
@@ -155,8 +156,6 @@
     backprops = []
     violns = [MaxViolation() for _ in range(len(states))]
     for t in range(max_steps):
-        if pbeam.is_done and gbeam.is_done:
-            break
         beam_maps.append({})
         states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1], nr_update)
         if not states:
@@ -174,16 +173,6 @@
 
     for i, violn in enumerate(violns):
         violn.check_crf(pbeam[i], gbeam[i])
-    # The non-monotonic oracle makes it difficult to ensure final costs are
-    # correct. Therefore do final correction
-    cdef Beam pred
-    for i, (pred, gold_parse) in enumerate(zip(pbeam, golds)):
-        for j in range(pred.size):
-            if is_gold(pred.at(j), gold_parse, moves.strings):
-                pred._states[j].loss = 0.0
-            elif pred._states[j].loss == 0.0:
-                pred._states[j].loss = 1.0
-        violn.check_crf(pred, gbeam[i])
     histories = [(v.p_hist + v.g_hist) for v in violns]
     losses = [(v.p_probs + v.g_probs) for v in violns]
 
@@ -199,20 +188,18 @@
     g_indices = []
     cdef Beam pbeam, gbeam
     for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
-        if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update):
-            continue
         p_indices.append([])
-        for j in range(pbeam.size):
-            state = pbeam.at(j)
+        for i in range(pbeam.size):
+            state = pbeam.at(i)
             if not state.is_final():
-                key = tuple([eg_id] + pbeam.histories[j])
+                key = tuple([eg_id] + pbeam.histories[i])
                 seen[key] = len(states)
                 p_indices[-1].append(len(states))
-                states.append(pbeam.at(j))
+                states.append(pbeam.at(i))
         beam_map.update(seen)
         g_indices.append([])
         for i in range(gbeam.size):
-            state = gbeam.at(j)
+            state = gbeam.at(i)
             if not state.is_final():
                 key = tuple([eg_id] + gbeam.histories[i])
                 if key in seen:
@@ -243,17 +230,17 @@
     nr_step = len(beam_maps)
     grads = []
     for beam_map in beam_maps:
-        if beam_map:
-            grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
-        else:
-            grads.append(None)
+        grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
+    assert len(histories) == len(losses)
     for eg_id, hists in enumerate(histories):
         for loss, hist in zip(losses[eg_id], hists):
             key = tuple([eg_id])
             for j, clas in enumerate(hist):
-                if grads[j] is None:
-                    continue
-                i = beam_maps[j][key]
+                try:
+                    i = beam_maps[j][key]
+                except KeyError:
+                    print(sorted(beam_maps[j].items()))
+                    raise
                 # In step j, at state i action clas
                 # resulted in loss
                 grads[j][i, clas] += loss
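Note on the reshaped get_gradient: with the None-padding branch gone, every step is expected to have a non-empty beam map, and each (eg_id, action-history-prefix) key addresses one row of that step's gradient matrix. A standalone sketch of that bookkeeping in plain Python/numpy; get_gradient_sketch and the final key-extension line are illustrative stand-ins, not spaCy's API (the real keys are built by get_states above):

    import numpy

    def get_gradient_sketch(nr_class, beam_maps, histories, losses):
        # One gradient matrix per step; one row per state scored at that step.
        grads = [numpy.zeros((max(m.values()) + 1, nr_class), dtype='f')
                 for m in beam_maps]
        for eg_id, hists in enumerate(histories):
            for loss, hist in zip(losses[eg_id], hists):
                key = tuple([eg_id])
                for j, clas in enumerate(hist):
                    i = beam_maps[j][key]      # row of this history prefix at step j
                    grads[j][i, clas] += loss  # in step j, at state i, action clas earned loss
                    key = key + tuple([clas])  # extend the prefix to address step j+1
        return grads

    # One example, one surviving hypothesis that took action 2, then action 0:
    maps = [{(0,): 0}, {(0, 2): 0}]
    print(get_gradient_sketch(3, maps, [[[2, 0]]], [[1.5]]))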
diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx
index e96e28fcf..f4f66f9fb 100644
--- a/spacy/syntax/beam_parser.pyx
+++ b/spacy/syntax/beam_parser.pyx
@@ -34,6 +34,7 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 from .parser cimport Parser
+from ._beam_utils import is_gold
 
 
 DEBUG = False
@@ -237,16 +238,3 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
         raise Exception("Gold parse is not gold-standard")
 
 
-def is_gold(StateClass state, GoldParse gold, StringStore strings):
-    predicted = set()
-    truth = set()
-    for i in range(gold.length):
-        if gold.cand_to_gold[i] is None:
-            continue
-        if state.safe_get(i).dep:
-            predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
-        else:
-            predicted.add((i, state.H(i), 'ROOT'))
-        id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
-        truth.add((id_, head, dep))
-    return truth == predicted
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index ea61af1df..51fd61cc1 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -66,7 +66,7 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
 from . import _beam_utils
 
 USE_FINE_TUNE = True
-BEAM_PARSE = False
+BEAM_PARSE = True
 
 def get_templates(*args, **kwargs):
     return []
@@ -348,6 +348,8 @@ cdef class Parser:
         The number of threads with which to work on the buffer in parallel.
         Yields (Doc): Documents, in order.
         """
+        if BEAM_PARSE:
+            beam_width = 8
         cdef Doc doc
         cdef Beam beam
         for docs in cytoolz.partition_all(batch_size, docs):
@@ -439,6 +441,8 @@ cdef class Parser:
                                                       cuda_stream, 0.0)
         beams = []
         cdef int offset = 0
+        cdef int j = 0
+        cdef int k
         for doc in docs:
             beam = Beam(nr_class, beam_width, min_density=beam_density)
             beam.initialize(self.moves.init_beam_state, doc.length, doc.c)
@@ -451,16 +455,22 @@ cdef class Parser:
             states = []
             for i in range(beam.size):
                 stcls = beam.at(i)
-                states.append(stcls)
+                # This way we avoid having to score finalized states
+                # We do have to take care to keep indexes aligned, though
+                if not stcls.is_final():
+                    states.append(stcls)
             token_ids = self.get_token_ids(states)
             vectors = state2vec(token_ids)
             scores = vec2scores(vectors)
+            j = 0
+            c_scores = <float*>scores.data
             for i in range(beam.size):
                 stcls = beam.at(i)
                 if not stcls.is_final():
                     self.moves.set_valid(beam.is_valid[i], stcls.c)
-                    for j in range(nr_class):
-                        beam.scores[i][j] = scores[i, j]
+                    for k in range(nr_class):
+                        beam.scores[i][k] = c_scores[j * scores.shape[1] + k]
+                    j += 1
             beam.advance(_transition_state, _hash_state, self.moves.c)
             beam.check_done(_check_final_state, NULL)
             beams.append(beam)
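The hunk above only feeds live states through state2vec/vec2scores, so the score matrix has one row per non-final state while the beam still indexes every slot; the j counter re-aligns the two. A minimal sketch of that alignment trick in plain Python, with dict-based states and a dummy score_fn standing in for the parser's types:

    import numpy

    def set_beam_scores(beam_states, score_fn, nr_class):
        # Score only states that are still alive; finalized slots get no row.
        alive = [s for s in beam_states if not s['is_final']]
        scores = score_fn(alive)            # shape (len(alive), nr_class)
        j = 0                               # next unused row of `scores`
        for state in beam_states:           # beam slot order is preserved
            if not state['is_final']:
                state['scores'] = scores[j]
                j += 1
        assert j == len(alive)

    states = [{'is_final': False}, {'is_final': True}, {'is_final': False}]
    set_beam_scores(states, lambda xs: numpy.ones((len(xs), 4), dtype='f'), 4)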
@@ -540,6 +550,7 @@ cdef class Parser:
             losses[self.name] = 0.
         docs, tokvecs = docs_tokvecs
         lengths = [len(d) for d in docs]
+        assert min(lengths) >= 1
         tokvecs = self.model[0].ops.flatten(tokvecs)
         if USE_FINE_TUNE:
             my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop)
@@ -554,9 +565,14 @@ cdef class Parser:
         states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, max_moves,
                                         states, tokvecs, golds,
                                         state2vec, vec2scores,
-                                        drop, sgd, losses)
+                                        drop, sgd, losses,
+                                        width=8)
         backprop_lower = []
         for i, d_scores in enumerate(states_d_scores):
+            if d_scores is None:
+                continue
+            if losses is not None:
+                losses[self.name] += (d_scores**2).sum()
             ids, bp_vectors, bp_scores = backprops[i]
             d_vector = bp_scores(d_scores, sgd=sgd)
             if isinstance(self.model[0].ops, CupyOps) \
@@ -617,14 +633,10 @@ cdef class Parser:
         xp = get_array_module(d_tokvecs)
         for ids, d_vector, bp_vector in backprops:
             d_state_features = bp_vector(d_vector, sgd=sgd)
-            active_feats = ids * (ids >= 0)
-            active_feats = active_feats.reshape((ids.shape[0], ids.shape[1], 1))
-            if hasattr(xp, 'scatter_add'):
-                xp.scatter_add(d_tokvecs,
-                    ids, d_state_features * active_feats)
-            else:
-                xp.add.at(d_tokvecs,
-                    ids, d_state_features * active_feats)
+            mask = ids >= 0
+            indices = xp.nonzero(mask)
+            self.model[0].ops.scatter_add(d_tokvecs, ids[indices],
+                d_state_features[indices])
 
     @property
     def move_names(self):
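The last nn_parser hunk replaces the old multiply-by-mask arithmetic with fancy indexing plus a single scatter-add: feature slots whose token id is negative (padding) are dropped before accumulating gradients into d_tokvecs. In plain numpy the same operation looks like this, with numpy.add.at standing in for ops.scatter_add and made-up shapes for illustration:

    import numpy

    nr_token, width = 5, 4
    d_tokvecs = numpy.zeros((nr_token, width), dtype='f')
    ids = numpy.array([[0, 2, -1],
                       [4, -1, 1]])               # -1 marks unused feature slots
    d_state_features = numpy.ones((2, 3, width), dtype='f')

    # Keep only entries with a real token id, then accumulate their gradients
    # into the corresponding token vector rows.
    indices = numpy.nonzero(ids >= 0)
    numpy.add.at(d_tokvecs, ids[indices], d_state_features[indices])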