diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index feb1fee13..2c6771ff2 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -35,14 +35,15 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): - labels = {SHIFT: {0: True}, REDUCE: {0: True}, RIGHT: {0: True}, - LEFT: {0: True}, BREAK: {0: True}} + labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {}, + LEFT: {}, BREAK: {'ROOT': True}} for parse in gold_parses: for i, (head, label) in enumerate(zip(parse.heads, parse.labels)): - if head > i: - labels[RIGHT][label] = True - else: - labels[LEFT][label] = True + if label != 'ROOT': + if head > i: + labels[RIGHT][label] = True + elif head < i: + labels[LEFT][label] = True return labels cdef Transition init_transition(self, int clas, int move, int label) except *: @@ -71,6 +72,8 @@ cdef class ArcEager(TransitionSystem): if scores[i] > score and is_valid[self.c[i].move]: best = self.c[i] score = scores[i] + assert best.clas < self.n_moves + assert score > MIN_SCORE # Label Shift moves with the best Right-Arc label, for non-monotonic # actions if best.move == SHIFT: @@ -85,7 +88,7 @@ cdef class ArcEager(TransitionSystem): cdef int _do_shift(const Transition* self, State* state) except -1: # Set the dep label, in case we need it after we reduce if NON_MONOTONIC: - get_s0(state).dep = self.label + state.sent[state.i].dep = self.label push_stack(state) @@ -124,7 +127,8 @@ do_funcs[BREAK] = _do_break cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1: - assert not at_eol(s) + if not _can_shift(s): + return 9000 cost = 0 cost += head_in_stack(s, s.i, gold.c_heads) cost += children_in_stack(s, s.i, gold.c_heads) @@ -137,7 +141,8 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1: - assert s.stack_len >= 1 + if not _can_right(s): + return 9000 cost = 0 if gold.c_heads[s.i] == s.stack[0]: cost += self.label != gold.c_labels[s.i] @@ -151,7 +156,8 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1: - assert s.stack_len >= 1 + if not _can_left(s): + return 9000 cost = 0 if gold.c_heads[s.stack[0]] == s.i: cost += self.label != gold.c_labels[s.stack[0]] @@ -166,6 +172,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1: + if not _can_reduce(s): + return 9000 cdef int cost = 0 cost += children_in_buffer(s, s.stack[0], gold.c_heads) if NON_MONOTONIC: @@ -174,6 +182,8 @@ cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) ex cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1: + if not _can_break(s): + return 9000 # When we break, we Reduce all of the words on the stack. cdef int cost = 0 # Number of deps between S0...Sn and N0...Nn