diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx
index 9091e539f..5ecd26283 100644
--- a/spacy/syntax/beam_parser.pyx
+++ b/spacy/syntax/beam_parser.pyx
@@ -72,7 +72,7 @@ def get_templates(name):
 
 
 cdef int BEAM_WIDTH = 16
-cdef weight_t BEAM_DENSITY = 0.001
+cdef weight_t BEAM_DENSITY = 0.01
 
 cdef class BeamParser(Parser):
     cdef public int beam_width
@@ -104,7 +104,7 @@ cdef class BeamParser(Parser):
         pred.initialize(_init_state, tokens.length, tokens.c)
         pred.check_done(_check_final_state, NULL)
 
-        cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
+        cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=0.0)
         gold.initialize(_init_state, tokens.length, tokens.c)
         gold.check_done(_check_final_state, NULL)
         violn = MaxViolation()
@@ -116,14 +116,22 @@ cdef class BeamParser(Parser):
             if pred.loss > 0 and pred.min_score > (gold.score + self.model.time):
                 break
         else:
+            # The non-monotonic oracle makes it difficult to ensure final costs are
+            # correct. Therefore do final correction
+            for i in range(pred.size):
+                if is_gold(pred.at(i), gold_parse, self.moves.strings):
+                    pred._states[i].loss = 0.0
+                elif pred._states[i].loss == 0.0:
+                    pred._states[i].loss = 1.0
             violn.check_crf(pred, gold)
-        min_grad = 0.001 ** (itn+1)
+        _check_train_integrity(pred, gold, gold_parse, self.moves)
         histories = zip(violn.p_probs, violn.p_hist) + zip(violn.g_probs, violn.g_hist)
+        min_grad = 0.001 ** (itn+1)
+        histories = [(grad, hist) for grad, hist in histories if abs(grad) >= min_grad]
         random.shuffle(histories)
         for grad, hist in histories:
             assert not math.isnan(grad) and not math.isinf(grad)
-            if abs(grad) >= min_grad:
-                self.model._update_from_history(self.moves, tokens, hist, grad)
+            self.model._update_from_history(self.moves, tokens, hist, grad)
         _cleanup(pred)
         _cleanup(gold)
         return pred.loss
@@ -131,25 +139,26 @@ cdef class BeamParser(Parser):
     def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold):
         cdef Pool mem = Pool()
         features = <FeatureC*>mem.alloc(self.model.nr_feat, sizeof(FeatureC))
-        cdef ParserNeuralNet nn_model = None
-        cdef ParserPerceptron ap_model = None
         if isinstance(self.model, ParserNeuralNet):
-            nn_model = self.model
+            mb = Minibatch(self.model.widths, beam.size)
+            for i in range(beam.size):
+                stcls = <StateClass>beam.at(i)
+                if stcls.c.is_final():
+                    nr_feat = 0
+                else:
+                    nr_feat = self.model.set_featuresC(features, stcls.c)
+                    self.moves.set_valid(beam.is_valid[i], stcls.c)
+                mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
+            self.model(mb)
+            for i in range(beam.size):
+                memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
         else:
-            ap_model = self.model
-            raise NotImplementedError
-        cdef Minibatch mb = Minibatch(nn_model.widths, beam.size)
-        for i in range(beam.size):
-            stcls = <StateClass>beam.at(i)
-            if stcls.c.is_final():
-                nr_feat = 0
-            else:
-                nr_feat = nn_model._set_featuresC(features, stcls.c)
-                self.moves.set_valid(beam.is_valid[i], stcls.c)
-            mb.c.push_back(features, nr_feat, beam.costs[i], beam.is_valid[i], 0)
-        self.model(mb)
-        for i in range(beam.size):
-            memcpy(beam.scores[i], mb.c.scores(i), mb.c.nr_out() * sizeof(beam.scores[i][0]))
+            for i in range(beam.size):
+                stcls = <StateClass>beam.at(i)
+                if not stcls.c.is_final():
+                    nr_feat = self.model.set_featuresC(features, stcls.c)
+                    self.moves.set_valid(beam.is_valid[i], stcls.c)
+                    self.model.set_scoresC(beam.scores[i], features, nr_feat)
         if gold is not None:
             for i in range(beam.size):
                 stcls = <StateClass>beam.at(i)
@@ -158,7 +167,10 @@ cdef class BeamParser(Parser):
                     if follow_gold:
                         for j in range(self.moves.n_moves):
                             beam.is_valid[i][j] *= beam.costs[i][j] < 1
-        beam.advance(_transition_state, _hash_state, self.moves.c)
+        if follow_gold:
+            beam.advance(_transition_state, NULL, self.moves.c)
+        else:
+            beam.advance(_transition_state, _hash_state, self.moves.c)
         beam.check_done(_check_final_state, NULL)
@@ -195,4 +207,51 @@ def _cleanup(Beam beam):
 
 cdef hash_t _hash_state(void* _state, void* _) except 0:
     state = <StateClass>_state
-    return state.c.hash()
+    if state.c.is_final():
+        return 1
+    else:
+        return state.c.hash()
+
+
+def _check_train_integrity(Beam pred, Beam gold, GoldParse gold_parse, TransitionSystem moves):
+    for i in range(pred.size):
+        if not pred._states[i].is_done or pred._states[i].loss == 0:
+            continue
+        state = <StateClass>pred.at(i)
+        if is_gold(state, gold_parse, moves.strings) == True:
+            print("Truth")
+            for dep in gold_parse.orig_annot:
+                print(dep[1], dep[3], dep[4])
+            print("Cost", pred._states[i].loss)
+            for j in range(gold_parse.length):
+                print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])
+            acts = [moves.c[clas].move for clas in pred.histories[i]]
+            labels = [moves.c[clas].label for clas in pred.histories[i]]
+            print([moves.move_name(move, label) for move, label in zip(acts, labels)])
+            raise Exception("Predicted state is gold-standard")
+    for i in range(gold.size):
+        if not gold._states[i].is_done:
+            continue
+        state = <StateClass>gold.at(i)
+        if is_gold(state, gold_parse, moves.strings) == False:
+            print("Truth")
+            for dep in gold_parse.orig_annot:
+                print(dep[1], dep[3], dep[4])
+            print("Predicted good")
+            for j in range(gold_parse.length):
+                print(gold_parse.orig_annot[j][1], state.H(j), moves.strings[state.safe_get(j).dep])
+            raise Exception("Gold parse is not gold-standard")
+
+
+def is_gold(StateClass state, GoldParse gold, StringStore strings):
+    predicted = set()
+    truth = set()
+    for i in range(gold.length):
+        if state.safe_get(i).dep:
+            predicted.add((i, state.H(i), strings[state.safe_get(i).dep]))
+        else:
+            predicted.add((i, state.H(i), 'ROOT'))
+        id_, word, tag, head, dep, ner = gold.orig_annot[i]
+        truth.add((id_, head, dep))
+    return truth == predicted
+
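
For reference, the check that the new `is_gold` helper performs can be sketched in plain Python: a beam state counts as gold-standard exactly when the set of (token, head, label) arcs it has produced matches the gold annotation, with unattached tokens treated as ROOT. The snippet below is a simplified illustration only; `is_gold_sketch`, the `state_arcs` layout, and the example sentence are made up for this sketch and are not the parser's actual `StateClass`/`GoldParse` API.

```python
# Simplified, pure-Python sketch of the arc-set comparison done by is_gold above.
# state_arcs[i] = (head_index, dep_label) stands in for StateClass.H(i) and safe_get(i).dep;
# gold_annot rows mirror GoldParse.orig_annot: (id, word, tag, head, dep, ner).

def is_gold_sketch(state_arcs, gold_annot):
    """Return True if the predicted (token, head, label) arcs match the gold arcs."""
    predicted = set()
    truth = set()
    for i, (head, dep) in enumerate(state_arcs):
        # Tokens without a dependency label are treated as ROOT, as in the patch.
        predicted.add((i, head, dep if dep else 'ROOT'))
        id_, word, tag, gold_head, gold_dep, ner = gold_annot[i]
        truth.add((id_, gold_head, gold_dep))
    return truth == predicted


# Hypothetical two-token example "She sleeps", where token 1 heads itself as the root.
gold_annot = [
    (0, 'She', 'PRP', 1, 'nsubj', 'O'),
    (1, 'sleeps', 'VBZ', 1, 'ROOT', 'O'),
]
print(is_gold_sketch([(1, 'nsubj'), (1, 'ROOT')], gold_annot))  # True: arcs match
print(is_gold_sketch([(1, 'dobj'), (1, 'ROOT')], gold_annot))   # False: wrong label
```

The same set comparison is what lets the final-cost correction in the training loop force `loss = 0.0` on beam states that happen to be gold despite the non-monotonic oracle, and `loss = 1.0` on non-gold states the oracle left uncosted.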