Add better arc-eager oracle tests

This commit is contained in:
Matthew Honnibal 2018-04-01 10:41:52 +02:00
parent 697bcaa34f
commit adbb1f7533

View File

@ -1,9 +1,170 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import pytest
from ...vocab import Vocab from ...vocab import Vocab
from ...pipeline import DependencyParser from ...pipeline import DependencyParser
from ...tokens import Doc from ...tokens import Doc
from ...gold import GoldParse from ...gold import GoldParse
from ...syntax.nonproj import projectivize from ...syntax.nonproj import projectivize
from ...syntax.stateclass import StateClass
from ...syntax.arc_eager import ArcEager
def get_sequence_costs(M, words, heads, deps, transitions):
doc = Doc(Vocab(), words=words)
gold = GoldParse(doc, heads=heads, deps=deps)
state = StateClass(doc)
M.preprocess_gold(gold)
cost_history = []
for gold_action in transitions:
state_costs = {}
for i in range(M.n_moves):
name = M.class_name(i)
state_costs[name] = M.get_cost(state, gold, i)
M.transition(state, gold_action)
cost_history.append(state_costs)
return state, cost_history
@pytest.fixture
def vocab():
return Vocab()
@pytest.fixture
def arc_eager(vocab):
moves = ArcEager(vocab.strings, ArcEager.get_actions())
moves.add_action(2, 'left')
moves.add_action(3, 'right')
return moves
@pytest.fixture
def words():
return ['a', 'b']
@pytest.fixture
def doc(words, vocab):
if vocab is None:
vocab = Vocab()
return Doc(vocab, words=list(words))
@pytest.fixture
def gold(doc, words):
if len(words) == 2:
return GoldParse(doc, words=['a', 'b'], heads=[0, 0], deps=['ROOT', 'right'])
else:
raise NotImplementedError
def test_shift_is_gold_at_first_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
assert arc_eager.get_cost(state, gold, 'S') == 0
def test_reduce_is_not_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'D') != 0
def test_break_is_not_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'B-ROOT') != 0
def test_left_is_not_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'L-left') != 0
def test_right_is_gold_at_second_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
assert arc_eager.get_cost(state, gold, 'R-right') == 0
def test_reduce_is_gold_at_third_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
arc_eager.transition(state, 'R-right')
assert arc_eager.get_cost(state, gold, 'D') == 0
def test_cant_arc_is_gold_at_third_state(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
arc_eager.transition(state, 'R-right')
assert not state.can_arc()
def test_fourth_state_is_final(arc_eager, doc, gold):
state = StateClass(doc)
arc_eager.preprocess_gold(gold)
arc_eager.transition(state, 'S')
arc_eager.transition(state, 'R-right')
arc_eager.transition(state, 'D')
assert state.is_final()
def test_oracle_sequence_two_words(arc_eager, doc, gold):
parser = DependencyParser(doc.vocab, moves=arc_eager)
state = StateClass(doc)
parser.moves.preprocess_gold(gold)
actions = parser.moves.get_oracle_sequence(doc, gold)
names = [parser.moves.class_name(i) for i in actions]
assert names == ['S', 'R-right', 'D']
def test_oracle_four_words(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'ROOT', 'left', 'ROOT']
actions = ['S', 'L-left', 'S', 'B-ROOT', 'S', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final()
for i, state_costs in enumerate(cost_history):
# Check gold moves is 0 cost
assert state_costs[actions[i]] == 0.0, actions[i]
for other_action, cost in state_costs.items():
if other_action != actions[i]:
assert cost >= 1
def test_non_monotonic_sequence_four_words(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'R-right', 'L-left', 'L-left', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final()
c0 = cost_history.pop(0)
assert c0['S'] == 0.0
c1 = cost_history.pop(0)
assert c1['L-left'] == 0.0
assert c1['R-right'] != 0.0
c2 = cost_history.pop(0)
assert c2['R-right'] != 0.0
assert c2['B-ROOT'] == 0.0
assert c2['D'] == 0.0
c3 = cost_history.pop(0)
assert c3['L-left'] == -1.0
def test_reduce_is_gold_at_break(arc_eager, vocab):
words = ['a', 'b', 'c', 'd']
heads = [1, 1, 3, 3]
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
actions = ['S', 'R-right', 'B-ROOT', 'D', 'S', 'L-left', 'S']
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
assert state.is_final(), state.print_state(words)
c0 = cost_history.pop(0)
c1 = cost_history.pop(0)
c2 = cost_history.pop(0)
c3 = cost_history.pop(0)
assert c3['D'] == 0.0
annot_tuples = [ annot_tuples = [
(0, 'When', 'WRB', 11, 'advmod', 'O'), (0, 'When', 'WRB', 11, 'advmod', 'O'),