diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 08eb23d1c..ba6e3af04 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -1,3 +1,6 @@ +# cython: profile=True +# cython: cdivision=True +# cython: infer_types=True from __future__ import unicode_literals import ctypes @@ -155,7 +158,18 @@ cdef class Reduce: @staticmethod cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: - return pop_cost(st, gold, st.S(0)) + cost = pop_cost(st, gold, st.S(0)) + if not st.has_head(st.S(0)): + # Decrement cost for the arcs e save + for i in range(1, st.stack_depth()): + S_i = st.S(i) + if gold.heads[st.S(0)] == S_i: + cost -= 1 + if gold.heads[S_i] == st.S(0): + cost -= 1 + if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0: + cost -= 1 + return cost @staticmethod cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -181,7 +195,8 @@ cdef class LeftArc: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: cdef weight_t cost = 0 if arc_is_gold(gold, s.B(0), s.S(0)): - return 0 + # Have a negative cost if we 'recover' from the wrong dependency + return 0 if not s.has_head(s.S(0)) else -1 else: # Account for deps we might lose between S0 and stack if not s.has_head(s.S(0)): @@ -281,7 +296,7 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil: cdef class ArcEager(TransitionSystem): @classmethod def get_actions(cls, **kwargs): - actions = kwargs.get('actions', + actions = kwargs.get('actions', { SHIFT: {'': True}, REDUCE: {'': True}, @@ -294,7 +309,7 @@ cdef class ArcEager(TransitionSystem): for label in kwargs.get('right_labels', []): if label.upper() != 'ROOT': actions[RIGHT][label] = True - + for raw_text, sents in kwargs.get('gold_parses', []): for (ids, words, tags, heads, labels, iob), ctnts in sents: for child, head, label in zip(ids, heads, labels): @@ -407,14 +422,14 @@ cdef class ArcEager(TransitionSystem): for i in range(self.n_moves): output[i] = is_valid[self.c[i].move] - cdef int set_costs(self, int* is_valid, weight_t* costs, + cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass stcls, GoldParse gold) except -1: cdef int i, move, label cdef label_cost_func_t[N_MOVES] label_cost_funcs cdef move_cost_func_t[N_MOVES] move_cost_funcs cdef weight_t[N_MOVES] move_costs for i in range(N_MOVES): - move_costs[i] = -1 + move_costs[i] = 9000 move_cost_funcs[SHIFT] = Shift.move_cost move_cost_funcs[REDUCE] = Reduce.move_cost move_cost_funcs[LEFT] = LeftArc.move_cost @@ -436,14 +451,14 @@ cdef class ArcEager(TransitionSystem): is_valid[i] = True move = self.c[i].move label = self.c[i].label - if move_costs[move] == -1: + if move_costs[move] == 9000: move_costs[move] = move_cost_funcs[move](stcls, &gold.c) costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) n_gold += costs[i] <= 0 else: is_valid[i] = False costs[i] = 9000 - if n_gold == 0: + if n_gold < 1: # Check projectivity --- leading cause if is_nonproj_tree(gold.heads): raise ValueError( @@ -463,7 +478,7 @@ cdef class ArcEager(TransitionSystem): "Could not find a gold-standard action to supervise the dependency " "parser.\n" "The GoldParse was projective.\n" - "The transition system has %d actions.\n" + "The transition system has %d actions.\n" "State at failure:\n" "%s" % (self.n_moves, stcls.print_state(gold.words))) assert n_gold >= 1