From 6e2564239d97caeada75a034e7674844009b589f Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 7 Jun 2015 19:12:59 +0200 Subject: [PATCH] * Bug fixes to beam parser. Search still broken on non-gold sentences --- spacy/syntax/parser.pyx | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 5fc0be0f9..639f91c03 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -83,7 +83,7 @@ cdef class Parser: def __call__(self, Tokens tokens): if tokens.length == 0: return 0 - if self.cfg.beam_width <= 1: + if self.cfg.get('beam_width', 1) <= 1: self._greedy_parse(tokens) else: self._beam_parse(tokens) @@ -113,6 +113,7 @@ cdef class Parser: cdef int _beam_parse(self, Tokens tokens) except -1: cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width) beam.initialize(_init_state, tokens.length, tokens.data) + beam.check_done(_check_final_state, NULL) while not beam.is_done: self._advance_beam(beam, None, False) state = beam.at(0) @@ -145,8 +146,10 @@ cdef class Parser: def _beam_train(self, Tokens tokens, GoldParse gold_parse): cdef Beam pred = Beam(self.moves.n_moves, self.cfg.beam_width) pred.initialize(_init_state, tokens.length, tokens.data) + pred.check_done(_check_final_state, NULL) cdef Beam gold = Beam(self.moves.n_moves, self.cfg.beam_width) gold.initialize(_init_state, tokens.length, tokens.data) + gold.check_done(_check_final_state, NULL) violn = MaxViolation() while not pred.is_done and not gold.is_done: @@ -170,9 +173,10 @@ cdef class Parser: cdef const Transition* move for i in range(beam.size): state = beam.at(i) - fill_context(context, state) - self.model.set_scores(beam.scores[i], context) - self.moves.set_valid(beam.is_valid[i], state) + if not is_final(state): + fill_context(context, state) + self.model.set_scores(beam.scores[i], context) + self.moves.set_valid(beam.is_valid[i], state) if gold is not None: for i in range(beam.size): @@ -194,8 +198,6 @@ cdef class Parser: cdef class_t clas cdef int n_feats for clas in hist: - if is_final(state): - break fill_context(context, state) feats = self.model._extractor.get_feats(context, &n_feats) count_feats(counts[clas], feats, n_feats, inc)