From adbb1f7533855f0ba27896edd58b3135da2c018d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 1 Apr 2018 10:41:52 +0200 Subject: [PATCH] Add better arc-eager oracle tests --- spacy/tests/parser/test_arc_eager_oracle.py | 161 ++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index 5f3a553e2..3145c5c07 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -1,9 +1,170 @@ from __future__ import unicode_literals +import pytest + from ...vocab import Vocab from ...pipeline import DependencyParser from ...tokens import Doc from ...gold import GoldParse from ...syntax.nonproj import projectivize +from ...syntax.stateclass import StateClass +from ...syntax.arc_eager import ArcEager + + +def get_sequence_costs(M, words, heads, deps, transitions): + doc = Doc(Vocab(), words=words) + gold = GoldParse(doc, heads=heads, deps=deps) + state = StateClass(doc) + M.preprocess_gold(gold) + cost_history = [] + for gold_action in transitions: + state_costs = {} + for i in range(M.n_moves): + name = M.class_name(i) + state_costs[name] = M.get_cost(state, gold, i) + M.transition(state, gold_action) + cost_history.append(state_costs) + return state, cost_history + + +@pytest.fixture +def vocab(): + return Vocab() + +@pytest.fixture +def arc_eager(vocab): + moves = ArcEager(vocab.strings, ArcEager.get_actions()) + moves.add_action(2, 'left') + moves.add_action(3, 'right') + return moves + +@pytest.fixture +def words(): + return ['a', 'b'] + +@pytest.fixture +def doc(words, vocab): + if vocab is None: + vocab = Vocab() + return Doc(vocab, words=list(words)) + +@pytest.fixture +def gold(doc, words): + if len(words) == 2: + return GoldParse(doc, words=['a', 'b'], heads=[0, 0], deps=['ROOT', 'right']) + else: + raise NotImplementedError + + +def test_shift_is_gold_at_first_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + assert arc_eager.get_cost(state, gold, 'S') == 0 + + +def test_reduce_is_not_gold_at_second_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + assert arc_eager.get_cost(state, gold, 'D') != 0 + + +def test_break_is_not_gold_at_second_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + assert arc_eager.get_cost(state, gold, 'B-ROOT') != 0 + +def test_left_is_not_gold_at_second_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + assert arc_eager.get_cost(state, gold, 'L-left') != 0 + +def test_right_is_gold_at_second_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + assert arc_eager.get_cost(state, gold, 'R-right') == 0 + + +def test_reduce_is_gold_at_third_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + arc_eager.transition(state, 'R-right') + assert arc_eager.get_cost(state, gold, 'D') == 0 + +def test_cant_arc_is_gold_at_third_state(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + arc_eager.transition(state, 'R-right') + assert not state.can_arc() + + +def test_fourth_state_is_final(arc_eager, doc, gold): + state = StateClass(doc) + arc_eager.preprocess_gold(gold) + arc_eager.transition(state, 'S') + arc_eager.transition(state, 'R-right') + arc_eager.transition(state, 'D') + assert state.is_final() + + +def test_oracle_sequence_two_words(arc_eager, doc, gold): + parser = DependencyParser(doc.vocab, moves=arc_eager) + state = StateClass(doc) + parser.moves.preprocess_gold(gold) + actions = parser.moves.get_oracle_sequence(doc, gold) + names = [parser.moves.class_name(i) for i in actions] + assert names == ['S', 'R-right', 'D'] + +def test_oracle_four_words(arc_eager, vocab): + words = ['a', 'b', 'c', 'd'] + heads = [1, 1, 3, 3] + deps = ['left', 'ROOT', 'left', 'ROOT'] + actions = ['S', 'L-left', 'S', 'B-ROOT', 'S', 'L-left', 'S'] + state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + assert state.is_final() + for i, state_costs in enumerate(cost_history): + # Check gold moves is 0 cost + assert state_costs[actions[i]] == 0.0, actions[i] + for other_action, cost in state_costs.items(): + if other_action != actions[i]: + assert cost >= 1 + +def test_non_monotonic_sequence_four_words(arc_eager, vocab): + words = ['a', 'b', 'c', 'd'] + heads = [1, 1, 3, 3] + deps = ['left', 'B-ROOT', 'left', 'B-ROOT'] + actions = ['S', 'R-right', 'R-right', 'L-left', 'L-left', 'L-left', 'S'] + state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + assert state.is_final() + c0 = cost_history.pop(0) + assert c0['S'] == 0.0 + c1 = cost_history.pop(0) + assert c1['L-left'] == 0.0 + assert c1['R-right'] != 0.0 + c2 = cost_history.pop(0) + assert c2['R-right'] != 0.0 + assert c2['B-ROOT'] == 0.0 + assert c2['D'] == 0.0 + c3 = cost_history.pop(0) + assert c3['L-left'] == -1.0 + + +def test_reduce_is_gold_at_break(arc_eager, vocab): + words = ['a', 'b', 'c', 'd'] + heads = [1, 1, 3, 3] + deps = ['left', 'B-ROOT', 'left', 'B-ROOT'] + actions = ['S', 'R-right', 'B-ROOT', 'D', 'S', 'L-left', 'S'] + state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + assert state.is_final(), state.print_state(words) + c0 = cost_history.pop(0) + c1 = cost_history.pop(0) + c2 = cost_history.pop(0) + c3 = cost_history.pop(0) + assert c3['D'] == 0.0 annot_tuples = [ (0, 'When', 'WRB', 11, 'advmod', 'O'),