mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Simplify Break transition to require stack depth 1. Hopefully as accurate
This commit is contained in:
parent
a37188fe98
commit
d8dec1134c
|
@ -74,8 +74,6 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
|
|||
break
|
||||
if BINARY_COSTS and cost >= 1:
|
||||
return cost
|
||||
#if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||
# cost += 1
|
||||
return cost
|
||||
|
||||
|
||||
|
@ -117,16 +115,12 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
|||
cdef class Shift:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if not st.can_push():
|
||||
return False
|
||||
elif st.stack_depth() == 0: # If the stack is empty, we must push
|
||||
return True
|
||||
elif st.shifted[st.B(0)]:
|
||||
return False
|
||||
elif st.at_break():
|
||||
return False
|
||||
if st.buffer_length == 0:
|
||||
return 0
|
||||
elif st.shifted[st.B(0)] and st.stack_depth() >= 1:
|
||||
return 0
|
||||
else:
|
||||
return True
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
|
@ -156,11 +150,9 @@ cdef class Reduce:
|
|||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
if st.stack_depth() >= 2:
|
||||
return True
|
||||
elif st.at_break() and st.stack_depth() == 1:
|
||||
return True
|
||||
return 1
|
||||
else:
|
||||
return False
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
|
@ -186,8 +178,6 @@ cdef class Reduce:
|
|||
cost -= 1
|
||||
if gold.heads[S_i] == st.S(0):
|
||||
cost -= 1
|
||||
#if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||
# cost -= 1
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
|
@ -198,7 +188,10 @@ cdef class Reduce:
|
|||
cdef class LeftArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
return st.can_arc()
|
||||
if st.buffer_length >= 1 and st.stack_depth() >= 1:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
|
@ -233,9 +226,11 @@ cdef class LeftArc:
|
|||
cdef class RightArc:
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
||||
if not st.can_arc():
|
||||
if st.stack_depth() < 1:
|
||||
return 0
|
||||
elif st.buffer_length == 0:
|
||||
return 0
|
||||
# If there's (perhaps partial) parse pre-set, don't allow cycle.
|
||||
elif st.H(st.S(0)) == st.B(0):
|
||||
return 0
|
||||
else:
|
||||
|
@ -269,13 +264,16 @@ cdef class Break:
|
|||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef int i
|
||||
if not USE_BREAK:
|
||||
return False
|
||||
return 0
|
||||
if st.stack_depth() != 1:
|
||||
return 0
|
||||
else:
|
||||
return st.can_break()
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(StateC* st, attr_t label) nogil:
|
||||
st.set_break(0)
|
||||
st._sent[st.S(0)].dep = label
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -131,7 +131,7 @@ def test_oracle_four_words(arc_eager, vocab):
|
|||
assert state_costs[actions[i]] == 0.0, actions[i]
|
||||
for other_action, cost in state_costs.items():
|
||||
if other_action != actions[i]:
|
||||
assert cost >= 1
|
||||
assert cost >= 1, (i, other_action, actions[i])
|
||||
|
||||
def test_non_monotonic_sequence_four_words(arc_eager, vocab):
|
||||
words = ['a', 'b', 'c', 'd']
|
||||
|
@ -147,24 +147,31 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab):
|
|||
assert c1['R-right'] != 0.0
|
||||
c2 = cost_history.pop(0)
|
||||
assert c2['R-right'] != 0.0
|
||||
assert c2['B-ROOT'] == 0.0
|
||||
assert c2['B-ROOT'] == 9000.0
|
||||
assert c2['D'] == 0.0
|
||||
c3 = cost_history.pop(0)
|
||||
assert c3['L-left'] == -1.0
|
||||
c4 = cost_history.pop(0)
|
||||
assert c4['D'] == 0.0
|
||||
c5 = cost_history.pop(0)
|
||||
assert c5['B-ROOT'] == 0.0
|
||||
|
||||
|
||||
def test_reduce_is_gold_at_break(arc_eager, vocab):
|
||||
def test_oracle_at_sentence_break(arc_eager, vocab):
|
||||
words = ['a', 'b', 'c', 'd']
|
||||
heads = [1, 1, 3, 3]
|
||||
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
|
||||
actions = ['S', 'R-right', 'B-ROOT', 'D', 'S', 'L-left', 'S']
|
||||
actions = ['S', 'R-right', 'D', 'B-ROOT', 'S']
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
||||
assert state.is_final(), state.print_state(words)
|
||||
assert not state.is_final(), state.print_state(words)
|
||||
c0 = cost_history.pop(0)
|
||||
c1 = cost_history.pop(0)
|
||||
c2 = cost_history.pop(0)
|
||||
c3 = cost_history.pop(0)
|
||||
assert c3['D'] == 0.0
|
||||
assert c2['D'] == 0.0
|
||||
assert c2['B-ROOT'] == 9000.0
|
||||
assert c3['B-ROOT'] == 0.0
|
||||
assert c3['D'] == 9000.0
|
||||
|
||||
annot_tuples = [
|
||||
(0, 'When', 'WRB', 11, 'advmod', 'O'),
|
||||
|
|
Loading…
Reference in New Issue
Block a user