diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx
index 48030b72a..7afe51d4f 100644
--- a/spacy/syntax/_beam_utils.pyx
+++ b/spacy/syntax/_beam_utils.pyx
@@ -6,6 +6,7 @@
 from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
 from thinc.extra.search cimport Beam
 from thinc.extra.search import MaxViolation
 from thinc.typedefs cimport hash_t, class_t
+from thinc.extra.search cimport MaxViolation
 from .transition_system cimport TransitionSystem, Transition
 from .stateclass cimport StateClass
@@ -45,6 +46,7 @@ cdef class ParserBeam(object):
     cdef public object states
     cdef public object golds
     cdef public object beams
+    cdef public object dones
 
     def __init__(self, TransitionSystem moves, states, golds,
                  int width=4, float density=0.001):
@@ -61,6 +63,7 @@ cdef class ParserBeam(object):
             st = beam.at(i)
             st.c.offset = state.c.offset
             self.beams.append(beam)
+        self.dones = [False] * len(self.beams)
 
     def __dealloc__(self):
         if self.beams is not None:
@@ -70,7 +73,7 @@
 
     @property
     def is_done(self):
-        return all(b.is_done for b in self.beams)
+        return all(b.is_done or self.dones[i] for i, b in enumerate(self.beams))
 
     def __getitem__(self, i):
         return self.beams[i]
@@ -81,19 +84,24 @@
     def advance(self, scores, follow_gold=False):
         cdef Beam beam
         for i, beam in enumerate(self.beams):
-            if beam.is_done or not scores[i].size:
+            if beam.is_done or not scores[i].size or self.dones[i]:
                 continue
             self._set_scores(beam, scores[i])
             if self.golds is not None:
                 self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
-            beam.advance(_transition_state, NULL, self.moves.c)
+            beam.advance(_transition_state, _hash_state, self.moves.c)
             beam.check_done(_check_final_state, NULL)
-            if beam.is_done:
+            if beam.is_done and self.golds is not None:
                 for j in range(beam.size):
-                    if is_gold(beam.at(j), self.golds[i], self.moves.strings):
-                        beam._states[j].loss = 0.0
-                    elif beam._states[j].loss == 0.0:
-                        beam._states[j].loss = 1.0
+                    state = beam.at(j)
+                    if state.is_final():
+                        try:
+                            if self.moves.is_gold_parse(state, self.golds[i]):
+                                beam._states[j].loss = 0.0
+                            elif beam._states[j].loss == 0.0:
+                                beam._states[j].loss = 1.0
+                        except NotImplementedError:
+                            break
 
     def _set_scores(self, Beam beam, float[:, ::1] scores):
         cdef float* c_scores = &scores[0, 0]
@@ -110,7 +118,6 @@ cdef class ParserBeam(object):
                 beam.scores[i][j] = 0
                 beam.costs[i][j] = 0
 
-
     def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
         for i in range(beam.size):
            state = beam.at(i)
@@ -122,21 +129,6 @@
                     beam.is_valid[i][j] = 0
 
 
-def is_gold(StateClass state, GoldParse gold, strings):
-    predicted = set()
-    truth = set()
-    for i in range(gold.length):
-        if gold.cand_to_gold[i] is None:
-            continue
-        if state.safe_get(i).dep:
-            predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
-        else:
-            predicted.add((i, state.H(i), 'ROOT'))
-        id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
-        truth.add((id_, head, dep))
-    return truth == predicted
-
-
 def get_token_ids(states, int n_tokens):
     cdef StateClass state
     cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
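
The `dones` list added to `ParserBeam` is a per-example kill switch that sits alongside the beams' own `is_done` flags: `get_states` (further down) can retire an example once its update is already decided, and both `is_done` and `advance` then skip it. A minimal runnable sketch of that bookkeeping, with hypothetical `ToyBeam`/`ToyParserBeam` classes standing in for the real Cython types:

```python
class ToyBeam:
    """Stand-in for thinc.extra.search.Beam; only tracks completion."""
    def __init__(self):
        self.is_done = False


class ToyParserBeam:
    """Sketch of the `dones` bookkeeping added to ParserBeam above."""
    def __init__(self, n_examples):
        self.beams = [ToyBeam() for _ in range(n_examples)]
        # One flag per example, so the trainer can retire an example
        # without waiting for its beam to parse to completion.
        self.dones = [False] * len(self.beams)

    @property
    def is_done(self):
        # Finished means: beam exhausted OR explicitly marked done.
        return all(b.is_done or self.dones[i]
                   for i, b in enumerate(self.beams))


batch = ToyParserBeam(3)
batch.dones[0] = True          # retired early by the trainer
batch.beams[1].is_done = True  # parsed to completion
print(batch.is_done)           # False: example 2 is still live
```
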
@@ -156,16 +148,19 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
                 state2vec, vec2scores, drop=0., sgd=None,
                 losses=None, int width=4, float density=0.001):
     global nr_update
+    cdef MaxViolation violn
     nr_update += 1
     pbeam = ParserBeam(moves, states, golds,
                        width=width, density=density)
     gbeam = ParserBeam(moves, states, golds,
-                       width=width, density=0.0)
+                       width=width, density=density)
     cdef StateClass state
     beam_maps = []
     backprops = []
     violns = [MaxViolation() for _ in range(len(states))]
     for t in range(max_steps):
+        if pbeam.is_done and gbeam.is_done:
+            break
         # The beam maps let us find the right row in the flattened scores
         # arrays for each state. States are identified by (example id, history).
         # We keep a different beam map for each step (since we'll have a flat
@@ -197,12 +192,16 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
         # Track the "maximum violation", to use in the update.
         for i, violn in enumerate(violns):
             violn.check_crf(pbeam[i], gbeam[i])
-
-    # Only make updates if we have non-gold states
-    histories = [((v.p_hist + v.g_hist) if v.p_hist else []) for v in violns]
-    losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns]
-    states_d_scores = get_gradient(moves.n_moves, beam_maps,
-                                   histories, losses)
+    histories = []
+    losses = []
+    for i, violn in enumerate(violns):
+        if violn.cost < 1:
+            histories.append([])
+            losses.append([])
+        else:
+            histories.append(violn.p_hist + violn.g_hist)
+            losses.append(violn.p_probs + violn.g_probs)
+    states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, losses)
     return states_d_scores, backprops[:len(states_d_scores)]
 
 
@@ -216,7 +215,9 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
     for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
         p_indices.append([])
         g_indices.append([])
-        if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update):
+        if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + numpy.sqrt(nr_update)):
+            pbeams.dones[eg_id] = True
+            gbeams.dones[eg_id] = True
             continue
         for i in range(pbeam.size):
             state = pbeam.at(i)
@@ -261,21 +262,21 @@ def get_gradient(nr_class, beam_maps, histories, losses):
     nr_step = 0
     for eg_id, hists in enumerate(histories):
         for loss, hist in zip(losses[eg_id], hists):
-            if abs(loss) >= 0.0001 and not numpy.isnan(loss):
+            if loss != 0.0 and not numpy.isnan(loss):
                 nr_step = max(nr_step, len(hist))
     for i in range(nr_step):
         grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class),
                                  dtype='f'))
     assert len(histories) == len(losses)
     for eg_id, hists in enumerate(histories):
         for loss, hist in zip(losses[eg_id], hists):
-            if abs(loss) < 0.0001 or numpy.isnan(loss):
+            if loss == 0.0 or numpy.isnan(loss):
                 continue
             key = tuple([eg_id])
             for j, clas in enumerate(hist):
                 i = beam_maps[j][key]
                 # In step j, at state i action clas
                 # resulted in loss
-                grads[j][i, clas] += loss / len(histories)
+                grads[j][i, clas] += loss
                 key = key + tuple([clas])
     return grads
diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx
index f4f66f9fb..68e9f27af 100644
--- a/spacy/syntax/beam_parser.pyx
+++ b/spacy/syntax/beam_parser.pyx
@@ -34,7 +34,6 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 from .parser cimport Parser
-from ._beam_utils import is_gold
 
 DEBUG = False
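
To make the `get_gradient` bookkeeping concrete: each surviving history is replayed action by action, and `beam_maps[j]` translates the `(eg_id, action, ...)` prefix into the row of step `j`'s flattened score matrix, which receives the loss (now unaveraged, per the change above). A self-contained toy version (hypothetical `toy_get_gradient`, toy inputs):

```python
import numpy

def toy_get_gradient(nr_class, beam_maps, histories, losses):
    # Longest history with a non-zero, non-NaN loss decides the step count.
    nr_step = max((len(hist) for hists_i, losses_i in zip(histories, losses)
                   for loss, hist in zip(losses_i, hists_i)
                   if loss != 0.0 and not numpy.isnan(loss)), default=0)
    grads = [numpy.zeros((max(beam_maps[j].values()) + 1, nr_class), dtype='f')
             for j in range(nr_step)]
    for eg_id, hists in enumerate(histories):
        for loss, hist in zip(losses[eg_id], hists):
            if loss == 0.0 or numpy.isnan(loss):
                continue
            key = (eg_id,)
            for j, clas in enumerate(hist):
                i = beam_maps[j][key]          # row for this state at step j
                grads[j][i, clas] += loss      # raw loss, no averaging
                key = key + (clas,)
    return grads

# One example whose history took action 1, then action 0, with loss 0.5.
beam_maps = [{(0,): 0}, {(0, 1): 0}]
grads = toy_get_gradient(2, beam_maps, [[[1, 0]]], [[0.5]])
print(grads[0])  # 0.5 credited to action 1 at step 0
```
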
@@ -108,7 +107,7 @@ cdef class BeamParser(Parser):
         # The non-monotonic oracle makes it difficult to ensure final costs are
         # correct. Therefore do final correction
         for i in range(pred.size):
-            if is_gold(pred.at(i), gold_parse, self.moves.strings):
+            if self.moves.is_gold_parse(pred.at(i), gold_parse):
                 pred._states[i].loss = 0.0
             elif pred._states[i].loss == 0.0:
                 pred._states[i].loss = 1.0
@@ -214,7 +213,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
         if not pred._states[i].is_done or pred._states[i].loss == 0:
             continue
         state = pred.at(i)
-        if is_gold(state, gold_parse, moves.strings) == True:
+        if moves.is_gold_parse(state, gold_parse) == True:
             for dep in gold_parse.orig_annot:
                 print(dep[1], dep[3], dep[4])
             print("Cost", pred._states[i].loss)
@@ -228,7 +227,7 @@ def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, Transitio
         if not gold._states[i].is_done:
             continue
         state = gold.at(i)
-        if is_gold(state, gold_parse, moves.strings) == False:
+        if moves.is_gold_parse(state, gold_parse) == False:
             print("Truth")
             for dep in gold_parse.orig_annot:
                 print(dep[1], dep[3], dep[4])
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index 9a35c69d7..11fc4e742 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -38,6 +38,7 @@
 from preshed.maps cimport map_get
 from thinc.api import layerize, chain, noop, clone
 from thinc.neural import Model, Affine, ReLu, Maxout
+from thinc.neural._classes.batchnorm import BatchNorm as BN
 from thinc.neural._classes.selu import SELU
 from thinc.neural._classes.layernorm import LayerNorm
 from thinc.neural.ops import NumpyOps, CupyOps
@@ -258,7 +259,7 @@ cdef class Parser:
         with Model.use_device('cpu'):
             upper = chain(
-                clone(Residual(ReLu(hidden_width)), (depth-1)),
+                clone(Maxout(hidden_width), (depth-1)),
                 zero_init(Affine(nr_class, drop_factor=0.0))
             )
         # TODO: This is an unfortunate hack atm!
@@ -321,6 +322,8 @@ cdef class Parser:
             beam_width = self.cfg.get('beam_width', 1)
         if beam_density is None:
             beam_density = self.cfg.get('beam_density', 0.001)
+        if BEAM_PARSE:
+            beam_width = 16
         cdef Beam beam
         if beam_width == 1:
             states = self.parse_batch([doc], [doc.tensor])
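
The model change above replaces the upper `Residual(ReLu(...))` blocks with plain `Maxout` layers. A maxout unit keeps the maximum over several linear pieces per output; the numpy sketch below illustrates the computation only, and its parameter shapes are an assumption for clarity, not thinc's internal layout:

```python
import numpy

def maxout(X, W, b):
    # X: (batch, nI); W: (pieces, nO, nI); b: (pieces, nO).
    pre = numpy.einsum('poi,ni->npo', W, X) + b  # (batch, pieces, nO)
    return pre.max(axis=1)                       # max over the pieces

rng = numpy.random.RandomState(0)
X = rng.randn(4, 8).astype('f')      # batch of 4, width 8
W = rng.randn(3, 8, 8).astype('f')   # 3 pieces, 8 -> 8
b = rng.randn(3, 8).astype('f')
print(maxout(X, W, b).shape)         # (4, 8)
```
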
""" if BEAM_PARSE: - beam_width = 8 + beam_width = 16 cdef Doc doc cdef Beam beam for docs in cytoolz.partition_all(batch_size, docs): @@ -427,7 +430,7 @@ cdef class Parser: next_step.push_back(st) return states - def beam_parse(self, docs, tokvecses, int beam_width=8, float beam_density=0.001): + def beam_parse(self, docs, tokvecses, int beam_width=16, float beam_density=0.001): cdef Beam beam cdef np.ndarray scores cdef Doc doc @@ -471,13 +474,13 @@ cdef class Parser: for k in range(nr_class): beam.scores[i][k] = c_scores[j * scores.shape[1] + k] j += 1 - beam.advance(_transition_state, NULL, self.moves.c) + beam.advance(_transition_state, _hash_state, self.moves.c) beam.check_done(_check_final_state, NULL) beams.append(beam) return beams def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): - if BEAM_PARSE: + if BEAM_PARSE and numpy.random.random() >= 0.5: return self.update_beam(docs_tokvecs, golds, drop=drop, sgd=sgd, losses=losses) if losses is not None and self.name not in losses: @@ -568,7 +571,7 @@ cdef class Parser: states, tokvecs, golds, state2vec, vec2scores, drop, sgd, losses, - width=8) + width=16) backprop_lower = [] for i, d_scores in enumerate(states_d_scores): if losses is not None: @@ -633,9 +636,10 @@ cdef class Parser: xp = get_array_module(d_tokvecs) for ids, d_vector, bp_vector in backprops: d_state_features = bp_vector(d_vector, sgd=sgd) - mask = (ids >= 0).reshape((ids.shape[0], ids.shape[1], 1)) - self.model[0].ops.scatter_add(d_tokvecs, ids, - d_state_features * mask) + mask = ids >= 0 + d_state_features *= mask.reshape(ids.shape + (1,)) + self.model[0].ops.scatter_add(d_tokvecs, ids * mask, + d_state_features) @property def move_names(self): @@ -651,7 +655,7 @@ cdef class Parser: lower, stream, drop=dropout) return state2vec, upper - nr_feature = 13 + nr_feature = 8 def get_token_ids(self, states): cdef StateClass state