Update test_arc_eager_oracle

This commit is contained in:
Matthew Honnibal 2020-06-21 01:12:28 +02:00
parent 7544c21f5b
commit 9db66ddd48

View File

@ -13,8 +13,9 @@ from spacy.syntax.arc_eager import ArcEager
def get_sequence_costs(M, words, heads, deps, transitions):
doc = Doc(Vocab(), words=words)
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
state = StateClass(doc)
M.preprocess_gold(example)
states, golds, _ = M.init_gold_batch([example])
state = states[0]
gold = golds[0]
cost_history = []
for gold_action in transitions:
state_costs = {}
@ -23,6 +24,7 @@ def get_sequence_costs(M, words, heads, deps, transitions):
state_costs[name] = M.get_cost(state, gold, i)
M.transition(state, gold_action)
cost_history.append(state_costs)
gold.update(state)
return state, cost_history
@ -59,7 +61,6 @@ def gold(doc, words):
raise NotImplementedError
@pytest.mark.xfail
def test_oracle_four_words(arc_eager, vocab):
words = ["a", "b", "c", "d"]
heads = [1, 1, 3, 3]
@ -144,12 +145,11 @@ def test_get_oracle_actions():
parser.moves.add_action(1, "")
parser.moves.add_action(1, "")
parser.moves.add_action(4, "ROOT")
heads, deps = projectivize(heads, deps)
for i, (head, dep) in enumerate(zip(heads, deps)):
if head > i:
parser.moves.add_action(2, dep)
elif head < i:
parser.moves.add_action(3, dep)
heads, deps = projectivize(heads, deps)
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
parser.moves.preprocess_gold(example)
parser.moves.get_oracle_sequence(example)