diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index a748bc894..4db9b1c18 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -35,6 +35,8 @@ cdef enum:
 
     BREAK
 
+    SPLIT
+
     N_MOVES
 
 
@@ -44,6 +46,7 @@ MOVE_NAMES[REDUCE] = 'D'
 MOVE_NAMES[LEFT] = 'L'
 MOVE_NAMES[RIGHT] = 'R'
 MOVE_NAMES[BREAK] = 'B'
+MOVE_NAMES[SPLIT] = 'P'
 
 
 # Helper functions for the arc-eager oracle
@@ -60,6 +63,10 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
         if BINARY_COSTS and cost >= 1:
             return cost
     cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
+    # If the token wasn't split before, but gold says it *should* be split,
+    # don't push (split instead)
+    if USE_SPLIT and not stcls.c.was_split[stcls.c.B(0)]:
+        cost += gold.fused[stcls.c.B(0)]
     return cost
 
 
@@ -112,6 +119,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
 cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
     return gold.heads[word] == word or not gold.has_dep[word]
 
+
 cdef class Shift:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
@@ -124,8 +132,6 @@ cdef class Shift:
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
-        #if label != 0:
-        #    st.split(st.B(1), label)
         st.shifted[st.B(0)] = 1
         st.push()
 
@@ -140,10 +146,48 @@ cdef class Shift:
     @staticmethod
     cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil:
         return 0
-        #if gold.fused_tokens[s.B(1)] == label: TODO
-        #    return 0
-        #else:
-        #    return 1
+
+
+# Split transition. Valid only when USE_SPLIT is on, the buffer is non-empty
+# and B(0) is not already flagged in st.is_split. The oracle charges nothing
+# when gold.fused[B(0)] is non-zero (move cost) and equals the label (label cost).
+cdef class Split:
+    @staticmethod
+    cdef bint is_valid(const StateC* st, attr_t label) nogil:
+        # NOTE(review): the oracle costs test `was_split`, this tests
+        # `is_split` -- confirm both flags exist and the difference is meant.
+        if not USE_SPLIT:
+            return 0
+        elif st.buffer_length == 0:
+            return 0
+        elif st.is_split[st.B(0)]:
+            return 0
+        else:
+            return 1
+
+    @staticmethod
+    cdef int transition(StateC* st, attr_t label) nogil:
+        # NOTE(review): is_valid and the costs look at st.B(0); presumably the
+        # 0 passed here is a buffer offset rather than a token index -- confirm.
+        st.split(0, label)
+
+    @staticmethod
+    cdef weight_t cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
+        return Split.move_cost(st, gold) + Split.label_cost(st, gold, label)
+
+    @staticmethod
+    cdef weight_t move_cost(StateClass st, const GoldParseC* gold) nogil:
+        if gold.fused[st.B(0)]:
+            return 0
+        else:
+            return 1
+
+    @staticmethod
+    cdef weight_t label_cost(StateClass st, const GoldParseC* gold, attr_t label) nogil:
+        if gold.fused[st.B(0)] == label:
+            return 0
+        else:
+            return 1
 
 
 cdef class Reduce:
@@ -247,7 +291,12 @@ cdef class RightArc:
 
     @staticmethod
     cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil:
-        if arc_is_gold(gold, s.S(0), s.B(0)):
+        # If the token wasn't split before, but gold says it *should* be split,
+        # don't right-arc (split instead). Guarded by USE_SPLIT like push_cost:
+        # with the flag off Split.is_valid is always 0, so there is no Split to prefer.
+        if USE_SPLIT and not s.c.was_split[s.c.B(0)] and gold.fused[s.c.B(0)]:
+            return gold.fused[s.c.B(0)]
+        elif arc_is_gold(gold, s.S(0), s.B(0)):
             return 0
         elif s.c.shifted[s.B(0)]:
             return push_cost(s, gold, s.B(0))
@@ -374,7 +423,7 @@ cdef class ArcEager(TransitionSystem):
 
     property action_types:
         def __get__(self):
-            return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
+            return (SHIFT, REDUCE, LEFT, RIGHT, BREAK, SPLIT)
 
     def get_cost(self, StateClass state, GoldParse gold, action):
         cdef Transition t = self.lookup_transition(action)
@@ -513,6 +562,10 @@ cdef class ArcEager(TransitionSystem):
             t.is_valid = Break.is_valid
             t.do = Break.transition
             t.get_cost = Break.cost
+        elif move == SPLIT:
+            t.is_valid = Split.is_valid
+            t.do = Split.transition
+            t.get_cost = Split.cost
         else:
             raise Exception(move)
         return t
@@ -543,6 +596,7 @@ cdef class ArcEager(TransitionSystem):
         is_valid[LEFT] = LeftArc.is_valid(st, 0)
         is_valid[RIGHT] = RightArc.is_valid(st, 0)
         is_valid[BREAK] = Break.is_valid(st, 0)
+        is_valid[SPLIT] = Split.is_valid(st, 0)
         cdef int i
         for i in range(self.n_moves):
             output[i] = is_valid[self.c[i].move]
@@ -561,12 +615,14 @@ cdef class ArcEager(TransitionSystem):
         move_cost_funcs[LEFT] = LeftArc.move_cost
         move_cost_funcs[RIGHT] = RightArc.move_cost
         move_cost_funcs[BREAK] = Break.move_cost
+        move_cost_funcs[SPLIT] = Split.move_cost
 
         label_cost_funcs[SHIFT] = Shift.label_cost
         label_cost_funcs[REDUCE] = Reduce.label_cost
         label_cost_funcs[LEFT] = LeftArc.label_cost
         label_cost_funcs[RIGHT] = RightArc.label_cost
         label_cost_funcs[BREAK] = Break.label_cost
+        label_cost_funcs[SPLIT] = Split.label_cost
 
         cdef attr_t* labels = gold.c.labels
         cdef int* heads = gold.c.heads