Update oracle tests for Split

This commit is contained in:
Matthew Honnibal 2018-04-03 15:44:58 +02:00
parent e31ef9c7f6
commit f7beefe9c1

View File

@ -10,9 +10,9 @@ from ...syntax.stateclass import StateClass
from ...syntax.arc_eager import ArcEager
def get_sequence_costs(M, words, heads, deps, transitions):
def get_sequence_costs(M, words, gold_words, heads, deps, transitions):
doc = Doc(Vocab(), words=words)
gold = GoldParse(doc, heads=heads, deps=deps)
gold = GoldParse(doc, words=gold_words, heads=heads, deps=deps)
state = StateClass(doc)
M.preprocess_gold(gold)
cost_history = []
@ -21,6 +21,7 @@ def get_sequence_costs(M, words, heads, deps, transitions):
for i in range(M.n_moves):
name = M.class_name(i)
state_costs[name] = M.get_cost(state, gold, i)
M.transition(state, gold_action)
cost_history.append(state_costs)
return state, cost_history
@ -124,7 +125,7 @@ def test_oracle_four_words(arc_eager, vocab):
heads = [1, 1, 3, 3]
deps = ['left', 'ROOT', 'left', 'ROOT']
actions = ['S', 'L-left', 'S', 'B-ROOT', 'S', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions)
assert state.is_final()
for i, state_costs in enumerate(cost_history):
# Check gold moves is 0 cost
@ -138,7 +139,7 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab):
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'R-right', 'L-left', 'L-left', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions)
assert state.is_final()
c0 = cost_history.pop(0)
assert c0['S'] == 0.0
@ -162,7 +163,7 @@ def test_oracle_at_sentence_break(arc_eager, vocab):
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'D', 'B-ROOT', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions)
assert not state.is_final(), state.print_state(words)
c0 = cost_history.pop(0)
c1 = cost_history.pop(0)
@ -175,18 +176,31 @@ def test_oracle_at_sentence_break(arc_eager, vocab):
def test_split_oracle(arc_eager, vocab):
if arc_eager.max_split < 2:
return
gold_words = ['a', 'b', 'c']
doc = Doc(vocab, words=['ab', 'c'])
words = ['ab', 'c']
doc = Doc(vocab, words=words)
heads = [2, 2, 2]
deps = ['dep', 'dep', 'ROOT']
actions = ['P-1', 'S', 'L-dep', 'S', 'B-ROOT']
actions = ['P-1', 'S', 'S', 'L-dep', 'L-dep', 'S', 'B-ROOT']
gold = GoldParse(doc, words=gold_words, heads=heads, deps=deps)
assert gold.heads == [[1, 1], 1]
assert gold.labels == [['dep', 'dep'], 'ROOT']
state = StateClass(doc)
M = arc_eager
M.add_action(5, '1')
M.preprocess_gold(gold)
assert gold.fused[0] == 1
state, cost_history = get_sequence_costs(arc_eager, words, gold_words, heads, deps, actions)
assert state.is_final()
for i, state_costs in enumerate(cost_history):
# Check gold moves is 0 cost
assert state_costs[actions[i]] == 0.0, actions[i]
for other_action, cost in state_costs.items():
if other_action != actions[i]:
assert cost >= 1, (i, other_action, actions[i])
annot_tuples = [
(0, 'When', 'WRB', 11, 'advmod', 'O'),