From 19ac03ce09d70b8d2b7d8a2b91d820e3b91da0e8 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 1 Apr 2018 14:32:15 +0200 Subject: [PATCH] Go back to letting Break work with deeper stacks It seems very appealing to restrict Break so that it only works when there's one word on the stack. Then we can pop that word, mark it as the root, and continue. However, results are suggesting it's nice to be able to predict Break when the last word of the previous sentence is on the stack, and the first word of the next sentence is at the buffer. This does make sense! Consider that the last word is often a period or something --- a pretty huge clue. We otherwise have to go out of our way to get that feature in. The really decisive thing is we have to handle upcoming sentence breaks anyway, because we need to conform to preset SBD constraints. So, we may as well let the parser predict the Break when it's at a stack/queue position that is most revealing. --- spacy/syntax/arc_eager.pyx | 26 +++++++++++++++++---- spacy/tests/parser/test_arc_eager_oracle.py | 4 ++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index d09f541f9..e51b61bdb 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -132,6 +132,8 @@ cdef class Shift: return 0 elif st.shifted[st.B(0)] and st.stack_depth() >= 1: return 0 + elif st.at_break() and st.stack_depth() >= 1: + return 0 else: return 1 @@ -193,6 +195,8 @@ cdef class Reduce: cdef bint is_valid(const StateC* st, attr_t label) nogil: if st.stack_depth() >= 2: return 1 + elif st.at_break() and st.stack_depth() == 1: + return 1 else: return 0 @@ -230,10 +234,14 @@ cdef class Reduce: cdef class LeftArc: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - if st.buffer_length >= 1 and st.stack_depth() >= 1: - return 1 - else: + if st.buffer_length == 0: return 0 + elif st.stack_depth() == 0: + return 0 + elif st.at_break(): + return 0 + else: + return 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -272,6 +280,8 @@ cdef class RightArc: return 0 elif st.buffer_length == 0: return 0 + elif st.at_break(): + return 0 # If there's (perhaps partial) parse pre-set, don't allow cycle. elif st.H(st.S(0)) == st.B(0): return 0 @@ -308,10 +318,16 @@ cdef class RightArc: cdef class Break: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: + # It would seem good to have a stack_depth==1 constraint here. + # That would make the other validities much less complicated. + # However, we need to know about upcoming sentence break to respect + # preset SBD anyway --- so we may as well give the parser the flexibility. cdef int i if not USE_BREAK: return 0 - if st.stack_depth() != 1: + elif st.stack_depth() < 1: + return 0 + elif st._sent[st.B_(0).l_edge].sent_start == -1: return 0 else: return 1 @@ -319,7 +335,6 @@ cdef class Break: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.set_break(0) - st._sent[st.S(0)].dep = label st.pop() @staticmethod @@ -661,6 +676,7 @@ cdef class ArcEager(TransitionSystem): print(gold.heads) print(gold.labels) print(gold.sent_starts) + print(stcls.history) raise ValueError( "Could not find a gold-standard action to supervise the" "dependency parser. The GoldParse was projective. The " diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index ca056ac42..bffe9621b 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -147,7 +147,7 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab): assert c1['R-right'] != 0.0 c2 = cost_history.pop(0) assert c2['R-right'] != 0.0 - assert c2['B-ROOT'] == 9000.0 + assert c2['B-ROOT'] == 0.0 assert c2['D'] == 0.0 c3 = cost_history.pop(0) assert c3['L-left'] == -1.0 @@ -169,7 +169,7 @@ def test_oracle_at_sentence_break(arc_eager, vocab): c2 = cost_history.pop(0) c3 = cost_history.pop(0) assert c2['D'] == 0.0 - assert c2['B-ROOT'] == 9000.0 + assert c2['B-ROOT'] == 0.0 assert c3['B-ROOT'] == 0.0 assert c3['D'] == 9000.0