From 9db66ddd4867c0d5db0967193e7adb249460c31d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 21 Jun 2020 01:12:28 +0200 Subject: [PATCH] Update test_arc_eager_oracle --- spacy/tests/parser/test_arc_eager_oracle.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index 39f682a34..c2ab94500 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -13,8 +13,9 @@ from spacy.syntax.arc_eager import ArcEager def get_sequence_costs(M, words, heads, deps, transitions): doc = Doc(Vocab(), words=words) example = Example.from_dict(doc, {"heads": heads, "deps": deps}) - state = StateClass(doc) - M.preprocess_gold(example) + states, golds, _ = M.init_gold_batch([example]) + state = states[0] + gold = golds[0] cost_history = [] for gold_action in transitions: state_costs = {} @@ -23,6 +24,7 @@ def get_sequence_costs(M, words, heads, deps, transitions): state_costs[name] = M.get_cost(state, gold, i) M.transition(state, gold_action) cost_history.append(state_costs) + gold.update(state) return state, cost_history @@ -59,7 +61,6 @@ def gold(doc, words): raise NotImplementedError -@pytest.mark.xfail def test_oracle_four_words(arc_eager, vocab): words = ["a", "b", "c", "d"] heads = [1, 1, 3, 3] @@ -144,12 +145,11 @@ def test_get_oracle_actions(): parser.moves.add_action(1, "") parser.moves.add_action(1, "") parser.moves.add_action(4, "ROOT") + heads, deps = projectivize(heads, deps) for i, (head, dep) in enumerate(zip(heads, deps)): if head > i: parser.moves.add_action(2, dep) elif head < i: parser.moves.add_action(3, dep) - heads, deps = projectivize(heads, deps) example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}) - parser.moves.preprocess_gold(example) parser.moves.get_oracle_sequence(example)