diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index afa05bd9a..f1dbcf426 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -1,6 +1,9 @@ # cython: profile=True from __future__ import unicode_literals +import ctypes +import os + from ._state cimport State from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep @@ -15,6 +18,12 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t from ..gold cimport GoldParse from ..gold cimport GoldParseC +from libc.stdint cimport uint32_t +from libc.string cimport memcpy + +from cymem.cymem cimport Pool +from ..stateclass cimport StateClass + DEF NON_MONOTONIC = True DEF USE_BREAK = True @@ -78,7 +87,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) return 0 - cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1: if gold.labels[child] == -1: return True @@ -110,6 +118,11 @@ cdef class Shift: cdef bint is_valid(const State* s, int label) except -1: return not at_eol(s) + @staticmethod + cdef bint _new_is_valid(StateClass st, int label) except -1: + return not st.eol() + + @staticmethod cdef int transition(State* state, int label) except -1: # Set the dep label, in case we need it after we reduce @@ -133,6 +146,13 @@ cdef class Shift: cdef class Reduce: + @staticmethod + cdef bint _new_is_valid(StateClass st, int label) except -1: + if NON_MONOTONIC: + return st.stack_depth() >= 2 #and not missing_brackets(s) + else: + return st.stack_depth() >= 2 and st.has_head(st.S(0)) + @staticmethod cdef bint is_valid(const State* s, int label) except -1: if NON_MONOTONIC: @@ -165,6 +185,13 @@ cdef class Reduce: cdef class LeftArc: + @staticmethod + cdef bint _new_is_valid(StateClass st, int label) except -1: + if NON_MONOTONIC: + return st.stack_depth() >= 1 #and not missing_brackets(s) + else: + return st.stack_depth() >= 1 and not st.has_head(st.S(0)) + @staticmethod cdef bint is_valid(const State* s, int label) except -1: if NON_MONOTONIC: @@ -206,6 +233,10 @@ cdef class RightArc: cdef bint is_valid(const State* s, int label) except -1: return s.stack_len >= 1 and not at_eol(s) + @staticmethod + cdef bint _new_is_valid(StateClass st, int label) except -1: + return st.stack_depth() >= 1 and not st.eol() + @staticmethod cdef int transition(State* state, int label) except -1: add_dep(state, state.stack[0], state.i, label) @@ -230,6 +261,32 @@ cdef class RightArc: cdef class Break: + @staticmethod + cdef bint _new_is_valid(StateClass st, int label) except -1: + cdef int i + if not USE_BREAK: + return False + elif st.eol(): + return False + elif st.stack_depth() < 1: + return False + elif NON_MONOTONIC: + return True + else: + # In the Break transition paper, they have this constraint that prevents + # Break if stack is disconnected. But, if we're doing non-monotonic parsing, + # we prefer to relax this constraint. This is helpful in parsing whole + # documents, because then we don't get stuck with words on the stack. + seen_headless = False + for i in range(st.stack_depth()): + if not st.has_head(st.S(i)): + if seen_headless: + return False + else: + seen_headless = True + # TODO: Constituency constraints + return True + @staticmethod cdef bint is_valid(const State* s, int label) except -1: cdef int i @@ -584,14 +641,17 @@ cdef class ArcEager(TransitionSystem): output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label) cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: + cdef Pool mem = Pool() + cdef StateClass stcls = StateClass.from_struct(mem, s) cdef bint[N_MOVES] is_valid - is_valid[SHIFT] = Shift.is_valid(s, -1) - is_valid[REDUCE] = Reduce.is_valid(s, -1) - is_valid[LEFT] = LeftArc.is_valid(s, -1) - is_valid[RIGHT] = RightArc.is_valid(s, -1) - is_valid[BREAK] = Break.is_valid(s, -1) - is_valid[CONSTITUENT] = Constituent.is_valid(s, -1) - is_valid[ADJUST] = Adjust.is_valid(s, -1) + #is_valid[SHIFT] = Shift.is_valid(s, -1) + is_valid[SHIFT] = Shift._new_is_valid(stcls, -1) + is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1) + is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1) + is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1) + is_valid[BREAK] = Break._new_is_valid(stcls, -1) + is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1) + is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1) cdef Transition best cdef weight_t score = MIN_SCORE cdef int i