diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 409676c55..d223f3968 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -1,3 +1,6 @@ +# cython: profile=True +# cython: cdivision=True +# cython: infer_types=True from __future__ import unicode_literals import ctypes @@ -24,6 +27,7 @@ from .nonproj import PseudoProjectivity DEF NON_MONOTONIC = True DEF USE_BREAK = True + cdef weight_t MIN_SCORE = -90000 # Break transition from here @@ -65,10 +69,12 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t cost = 0 cdef int i, B_i + # Count number of words in buffer with dependencies to/from the target. for i in range(stcls.buffer_length()): B_i = stcls.B(i) cost += gold.heads[B_i] == target cost += gold.heads[target] == B_i + # TODO: Should re-examine this for German --- it assumes projectivity. if gold.heads[B_i] == B_i or gold.heads[B_i] < target: break if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0: @@ -154,7 +160,18 @@ cdef class Reduce: @staticmethod cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: - return pop_cost(st, gold, st.S(0)) + cost = pop_cost(st, gold, st.S(0)) + if not st.has_head(st.S(0)): + # Decrement cost for the arcs we save + for i in range(1, st.stack_depth()): + S_i = st.S(i) + if gold.heads[st.S(0)] == S_i: + cost -= 1 + if gold.heads[S_i] == st.S(0): + cost -= 1 + if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0: + cost -= 1 + return cost @staticmethod cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -180,7 +197,8 @@ cdef class LeftArc: cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: cdef weight_t cost = 0 if arc_is_gold(gold, s.B(0), s.S(0)): - return 0 + # Have a negative cost if we 'recover' from the wrong dependency + return 0 if not 
s.has_head(s.S(0)) else -1 else: # Account for deps we might lose between S0 and stack if not s.has_head(s.S(0)): @@ -407,7 +425,7 @@ cdef class ArcEager(TransitionSystem): cdef move_cost_func_t[N_MOVES] move_cost_funcs cdef weight_t[N_MOVES] move_costs for i in range(N_MOVES): - move_costs[i] = -1 + move_costs[i] = 9000 move_cost_funcs[SHIFT] = Shift.move_cost move_cost_funcs[REDUCE] = Reduce.move_cost move_cost_funcs[LEFT] = LeftArc.move_cost @@ -429,10 +447,10 @@ cdef class ArcEager(TransitionSystem): is_valid[i] = True move = self.c[i].move label = self.c[i].label - if move_costs[move] == -1: + if move_costs[move] == 9000: move_costs[move] = move_cost_funcs[move](stcls, &gold.c) costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) - n_gold += costs[i] == 0 + n_gold += costs[i] <= 0 else: is_valid[i] = False costs[i] = 9000