diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 82ac43cb1..a748bc894 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -74,8 +74,6 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog break if BINARY_COSTS and cost >= 1: return cost - #if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0: - # cost += 1 return cost @@ -117,16 +115,12 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: cdef class Shift: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - if not st.can_push(): - return False - elif st.stack_depth() == 0: # If the stack is empty, we must push - return True - elif st.shifted[st.B(0)]: - return False - elif st.at_break(): - return False + if st.buffer_length == 0: + return 0 + elif st.shifted[st.B(0)] and st.stack_depth() >= 1: + return 0 else: - return True + return 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -156,11 +150,9 @@ cdef class Reduce: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: if st.stack_depth() >= 2: - return True - elif st.at_break() and st.stack_depth() == 1: - return True + return 1 else: - return False + return 0 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -186,8 +178,6 @@ cdef class Reduce: cost -= 1 if gold.heads[S_i] == st.S(0): cost -= 1 - #if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0: - # cost -= 1 return cost @staticmethod @@ -198,7 +188,10 @@ cdef class Reduce: cdef class LeftArc: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - return st.can_arc() + if st.buffer_length >= 1 and st.stack_depth() >= 1: + return 1 + else: + return 0 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -233,9 +226,11 @@ cdef class LeftArc: cdef class RightArc: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: - # If there's (perhaps partial) parse pre-set, don't allow cycle. - if not st.can_arc(): + if st.stack_depth() < 1: return 0 + elif st.buffer_length == 0: + return 0 + # If there's (perhaps partial) parse pre-set, don't allow cycle. elif st.H(st.S(0)) == st.B(0): return 0 else: @@ -269,13 +264,16 @@ cdef class Break: cdef bint is_valid(const StateC* st, attr_t label) nogil: cdef int i if not USE_BREAK: - return False + return 0 + if st.stack_depth() != 1: + return 0 else: - return st.can_break() + return 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: st.set_break(0) + st._sent[st.S(0)].dep = label st.pop() @staticmethod diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index 3145c5c07..ca056ac42 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -131,7 +131,7 @@ def test_oracle_four_words(arc_eager, vocab): assert state_costs[actions[i]] == 0.0, actions[i] for other_action, cost in state_costs.items(): if other_action != actions[i]: - assert cost >= 1 + assert cost >= 1, (i, other_action, actions[i]) def test_non_monotonic_sequence_four_words(arc_eager, vocab): words = ['a', 'b', 'c', 'd'] @@ -147,24 +147,31 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab): assert c1['R-right'] != 0.0 c2 = cost_history.pop(0) assert c2['R-right'] != 0.0 - assert c2['B-ROOT'] == 0.0 + assert c2['B-ROOT'] == 9000.0 assert c2['D'] == 0.0 c3 = cost_history.pop(0) assert c3['L-left'] == -1.0 + c4 = cost_history.pop(0) + assert c4['D'] == 0.0 + c5 = cost_history.pop(0) + assert c5['B-ROOT'] == 0.0 -def test_reduce_is_gold_at_break(arc_eager, vocab): +def test_oracle_at_sentence_break(arc_eager, vocab): words = ['a', 'b', 'c', 'd'] heads = [1, 1, 3, 3] deps = ['left', 'B-ROOT', 'left', 'B-ROOT'] - actions = ['S', 'R-right', 'B-ROOT', 'D', 'S', 'L-left', 'S'] + actions = ['S', 'R-right', 'D', 'B-ROOT', 'S'] state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) - assert state.is_final(), state.print_state(words) + assert not state.is_final(), state.print_state(words) c0 = cost_history.pop(0) c1 = cost_history.pop(0) c2 = cost_history.pop(0) c3 = cost_history.pop(0) - assert c3['D'] == 0.0 + assert c2['D'] == 0.0 + assert c2['B-ROOT'] == 9000.0 + assert c3['B-ROOT'] == 0.0 + assert c3['D'] == 9000.0 annot_tuples = [ (0, 'When', 'WRB', 11, 'advmod', 'O'),