diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index d09f541f9..e51b61bdb 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -132,6 +132,8 @@ cdef class Shift: return 0 elif st.shifted[st.B(0)] and st.stack_depth() >= 1: return 0 + elif st.at_break() and st.stack_depth() >= 1: + return 0 else: return 1 @@ -193,6 +195,8 @@ cdef class Reduce: cdef bint is_valid(const StateC* st, attr_t label) nogil: if st.stack_depth() >= 2: return 1 + elif st.at_break() and st.stack_depth() == 1: + return 1 else: return 0 @@ -230,10 +234,14 @@ cdef class Reduce: cdef class LeftArc: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - if st.buffer_length >= 1 and st.stack_depth() >= 1: - return 1 - else: + if st.buffer_length == 0: return 0 + elif st.stack_depth() == 0: + return 0 + elif st.at_break(): + return 0 + else: + return 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -272,6 +280,8 @@ cdef class RightArc: return 0 elif st.buffer_length == 0: return 0 + elif st.at_break(): + return 0 # If there's (perhaps partial) parse pre-set, don't allow cycle. elif st.H(st.S(0)) == st.B(0): return 0 @@ -308,10 +318,16 @@ cdef class RightArc: cdef class Break: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: + # It would seem good to have a stack_depth==1 constraint here. + # That would make the other validities much less complicated. + # However, we need to know about upcoming sentence break to respect + # preset SBD anyway --- so we may as well give the parser the flexibility. cdef int i if not USE_BREAK: return 0 - if st.stack_depth() != 1: + elif st.stack_depth() < 1: + return 0 + elif st._sent[st.B_(0).l_edge].sent_start == -1: return 0 else: return 1 @@ -319,7 +335,6 @@ cdef class Break: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.set_break(0) - st._sent[st.S(0)].dep = label st.pop() @staticmethod @@ -661,6 +676,7 @@ cdef class ArcEager(TransitionSystem): print(gold.heads) print(gold.labels) print(gold.sent_starts) + print(stcls.history) raise ValueError( "Could not find a gold-standard action to supervise the" "dependency parser. The GoldParse was projective. The " diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index ca056ac42..bffe9621b 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -147,7 +147,7 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab): assert c1['R-right'] != 0.0 c2 = cost_history.pop(0) assert c2['R-right'] != 0.0 - assert c2['B-ROOT'] == 9000.0 + assert c2['B-ROOT'] == 0.0 assert c2['D'] == 0.0 c3 = cost_history.pop(0) assert c3['L-left'] == -1.0 @@ -169,7 +169,7 @@ def test_oracle_at_sentence_break(arc_eager, vocab): c2 = cost_history.pop(0) c3 = cost_history.pop(0) assert c2['D'] == 0.0 - assert c2['B-ROOT'] == 9000.0 + assert c2['B-ROOT'] == 0.0 assert c3['B-ROOT'] == 0.0 assert c3['D'] == 9000.0