From f7beefe9c19d856990625a379ba34462bf33ad4d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Apr 2018 15:44:58 +0200 Subject: [PATCH] Update oracle tests for Split --- spacy/tests/parser/test_arc_eager_oracle.py | 30 +++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index a7b973206..bef0f1255 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -10,9 +10,9 @@ from ...syntax.stateclass import StateClass from ...syntax.arc_eager import ArcEager -def get_sequence_costs(M, words, heads, deps, transitions): +def get_sequence_costs(M, words, gold_words, heads, deps, transitions): doc = Doc(Vocab(), words=words) - gold = GoldParse(doc, heads=heads, deps=deps) + gold = GoldParse(doc, words=gold_words, heads=heads, deps=deps) state = StateClass(doc) M.preprocess_gold(gold) cost_history = [] @@ -21,6 +21,7 @@ def get_sequence_costs(M, words, heads, deps, transitions): for i in range(M.n_moves): name = M.class_name(i) state_costs[name] = M.get_cost(state, gold, i) + M.transition(state, gold_action) cost_history.append(state_costs) return state, cost_history @@ -124,7 +125,7 @@ def test_oracle_four_words(arc_eager, vocab): heads = [1, 1, 3, 3] deps = ['left', 'ROOT', 'left', 'ROOT'] actions = ['S', 'L-left', 'S', 'B-ROOT', 'S', 'L-left', 'S'] - state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions) assert state.is_final() for i, state_costs in enumerate(cost_history): # Check gold moves is 0 cost @@ -138,7 +139,7 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab): heads = [1, 1, 3, 3] deps = ['left', 'B-ROOT', 'left', 'B-ROOT'] actions = ['S', 'R-right', 'R-right', 'L-left', 'L-left', 'L-left', 'S'] - state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions) assert state.is_final() c0 = cost_history.pop(0) assert c0['S'] == 0.0 @@ -162,7 +163,7 @@ def test_oracle_at_sentence_break(arc_eager, vocab): heads = [1, 1, 3, 3] deps = ['left', 'B-ROOT', 'left', 'B-ROOT'] actions = ['S', 'R-right', 'D', 'B-ROOT', 'S'] - state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions) + state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions) assert not state.is_final(), state.print_state(words) c0 = cost_history.pop(0) c1 = cost_history.pop(0) @@ -175,18 +176,31 @@ def test_oracle_at_sentence_break(arc_eager, vocab): def test_split_oracle(arc_eager, vocab): + if arc_eager.max_split < 2: + return gold_words = ['a', 'b', 'c'] - doc = Doc(vocab, words=['ab', 'c']) + words = ['ab', 'c'] + doc = Doc(vocab, words=words) heads = [2, 2, 2] deps = ['dep', 'dep', 'ROOT'] - actions = ['P-1', 'S', 'L-dep', 'S', 'B-ROOT'] + actions = ['P-1', 'S', 'S', 'L-dep', 'L-dep', 'S', 'B-ROOT'] gold = GoldParse(doc, words=gold_words, heads=heads, deps=deps) assert gold.heads == [[1, 1], 1] assert gold.labels == [['dep', 'dep'], 'ROOT'] state = StateClass(doc) M = arc_eager + M.add_action(5, '1') M.preprocess_gold(gold) - + assert gold.fused[0] == 1 + state, cost_history = get_sequence_costs(arc_eager, words, gold_words, heads, deps, actions) + assert state.is_final() + for i, state_costs in enumerate(cost_history): + # Check gold moves is 0 cost + assert state_costs[actions[i]] == 0.0, actions[i] + for other_action, cost in state_costs.items(): + if other_action != actions[i]: + assert cost >= 1, (i, other_action, actions[i]) + annot_tuples = [ (0, 'When', 'WRB', 11, 'advmod', 'O'),