Simplify Break transition to require stack depth 1. Hopefully as accurate

This commit is contained in:
Matthew Honnibal 2018-04-01 12:53:25 +02:00
parent a37188fe98
commit d8dec1134c
2 changed files with 33 additions and 28 deletions

View File

@ -74,8 +74,6 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
break
if BINARY_COSTS and cost >= 1:
return cost
#if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
# cost += 1
return cost
@ -117,16 +115,12 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
cdef class Shift:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if not st.can_push():
return False
elif st.stack_depth() == 0: # If the stack is empty, we must push
return True
elif st.shifted[st.B(0)]:
return False
elif st.at_break():
return False
if st.buffer_length == 0:
return 0
elif st.shifted[st.B(0)] and st.stack_depth() >= 1:
return 0
else:
return True
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -156,11 +150,9 @@ cdef class Reduce:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
if st.stack_depth() >= 2:
return True
elif st.at_break() and st.stack_depth() == 1:
return True
return 1
else:
return False
return 0
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -186,8 +178,6 @@ cdef class Reduce:
cost -= 1
if gold.heads[S_i] == st.S(0):
cost -= 1
#if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
# cost -= 1
return cost
@staticmethod
@ -198,7 +188,10 @@ cdef class Reduce:
cdef class LeftArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
return st.can_arc()
if st.buffer_length >= 1 and st.stack_depth() >= 1:
return 1
else:
return 0
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
@ -233,9 +226,11 @@ cdef class LeftArc:
cdef class RightArc:
@staticmethod
cdef bint is_valid(const StateC* st, attr_t label) nogil:
# If there's (perhaps partial) parse pre-set, don't allow cycle.
if not st.can_arc():
if st.stack_depth() < 1:
return 0
elif st.buffer_length == 0:
return 0
# If there's (perhaps partial) parse pre-set, don't allow cycle.
elif st.H(st.S(0)) == st.B(0):
return 0
else:
@ -269,13 +264,16 @@ cdef class Break:
cdef bint is_valid(const StateC* st, attr_t label) nogil:
cdef int i
if not USE_BREAK:
return False
return 0
if st.stack_depth() != 1:
return 0
else:
return st.can_break()
return 1
@staticmethod
cdef int transition(StateC* st, attr_t label) nogil:
st.set_break(0)
st._sent[st.S(0)].dep = label
st.pop()
@staticmethod

View File

@ -131,7 +131,7 @@ def test_oracle_four_words(arc_eager, vocab):
assert state_costs[actions[i]] == 0.0, actions[i]
for other_action, cost in state_costs.items():
if other_action != actions[i]:
assert cost >= 1
assert cost >= 1, (i, other_action, actions[i])
def test_non_monotonic_sequence_four_words(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
@ -147,24 +147,31 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab):
assert c1['R-right'] != 0.0
c2 = cost_history.pop(0)
assert c2['R-right'] != 0.0
assert c2['B-ROOT'] == 0.0
assert c2['B-ROOT'] == 9000.0
assert c2['D'] == 0.0
c3 = cost_history.pop(0)
assert c3['L-left'] == -1.0
c4 = cost_history.pop(0)
assert c4['D'] == 0.0
c5 = cost_history.pop(0)
assert c5['B-ROOT'] == 0.0
def test_reduce_is_gold_at_break(arc_eager, vocab):
def test_oracle_at_sentence_break(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'B-ROOT', 'D', 'S', 'L-left', 'S']
actions = ['S', 'R-right', 'D', 'B-ROOT', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final(), state.print_state(words)
assert not state.is_final(), state.print_state(words)
c0 = cost_history.pop(0)
c1 = cost_history.pop(0)
c2 = cost_history.pop(0)
c3 = cost_history.pop(0)
assert c3['D'] == 0.0
assert c2['D'] == 0.0
assert c2['B-ROOT'] == 9000.0
assert c3['B-ROOT'] == 0.0
assert c3['D'] == 9000.0
annot_tuples = [
(0, 'When', 'WRB', 11, 'advmod', 'O'),