From b3fd48c97b7d3c67a0faedb0094009b4b961c723 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 28 Apr 2015 20:45:51 +0200 Subject: [PATCH] * Fix missing root labels bug identified in Issue #57 --- spacy/syntax/arc_eager.pyx | 8 +++++++- spacy/syntax/ner.pyx | 3 --- spacy/syntax/parser.pyx | 8 ++++---- spacy/syntax/transition_system.pxd | 3 ++- spacy/syntax/transition_system.pyx | 7 +++++-- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index fb544aa3e..7d3d36347 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -88,9 +88,15 @@ cdef class ArcEager(TransitionSystem): t.get_cost = get_cost_funcs[move] return t - cdef int first_state(self, State* state) except -1: + cdef int initialize_state(self, State* state) except -1: push_stack(state) + cdef int finalize_state(self, State* state) except -1: + cdef int root_label = self.strings['ROOT'] + for i in range(state.sent_len): + if state.sent[i].head == 0 and state.sent[i].dep == 0: + state.sent[i].dep = root_label + cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = _can_shift(s) diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 8622d7894..0e1bc6b20 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -124,9 +124,6 @@ cdef class BiluoPushDown(TransitionSystem): t.get_cost = _get_cost return t - cdef int first_state(self, State* state) except -1: - pass - cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef int best = -1 cdef weight_t score = -90000 diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index ab9de48b8..58e98c1e1 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -83,15 +83,14 @@ cdef class GreedyParser: cdef int n_feats cdef Pool mem = Pool() cdef State* state = new_state(mem, tokens.data, tokens.length) - self.moves.first_state(state) + self.moves.initialize_state(state) cdef Transition guess while not is_final(state): fill_context(context, state) scores = self.model.score(context) guess = self.moves.best_valid(scores, state) - #print self.moves.move_name(guess.move, guess.label), - #print print_state(state, [w.orth_ for w in tokens]) guess.do(&guess, state) + self.moves.finalize_state(state) tokens.set_parse(state.sent) return 0 @@ -99,7 +98,7 @@ cdef class GreedyParser: self.moves.preprocess_gold(gold) cdef Pool mem = Pool() cdef State* state = new_state(mem, tokens.data, tokens.length) - self.moves.first_state(state) + self.moves.initialize_state(state) cdef int cost cdef const Feature* feats @@ -117,3 +116,4 @@ cdef class GreedyParser: self.model.update(context, guess.clas, best.clas, cost) guess.do(&guess, state) + self.moves.finalize_state(state) diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 58aa90d99..f0eac376a 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -30,7 +30,8 @@ cdef class TransitionSystem: cdef const Transition* c cdef readonly int n_moves - cdef int first_state(self, State* state) except -1 + cdef int initialize_state(self, State* state) except -1 + cdef int finalize_state(self, State* state) except -1 cdef int preprocess_gold(self, GoldParse gold) except -1 diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index e948d483d..0fea8d8c4 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -26,8 +26,11 @@ cdef class TransitionSystem: i += 1 self.c = moves - cdef int first_state(self, State* state) except -1: - raise NotImplementedError + cdef int initialize_state(self, State* state) except -1: + pass + + cdef int finalize_state(self, State* state) except -1: + pass cdef int preprocess_gold(self, GoldParse gold) except -1: raise NotImplementedError