Update test_arc_eager_oracle

This commit is contained in:
Matthew Honnibal 2020-06-21 01:12:28 +02:00
parent 7544c21f5b
commit 9db66ddd48

View File

@@ -13,8 +13,9 @@ from spacy.syntax.arc_eager import ArcEager
def get_sequence_costs(M, words, heads, deps, transitions): def get_sequence_costs(M, words, heads, deps, transitions):
doc = Doc(Vocab(), words=words) doc = Doc(Vocab(), words=words)
example = Example.from_dict(doc, {"heads": heads, "deps": deps}) example = Example.from_dict(doc, {"heads": heads, "deps": deps})
state = StateClass(doc) states, golds, _ = M.init_gold_batch([example])
M.preprocess_gold(example) state = states[0]
gold = golds[0]
cost_history = [] cost_history = []
for gold_action in transitions: for gold_action in transitions:
state_costs = {} state_costs = {}
@@ -23,6 +24,7 @@ def get_sequence_costs(M, words, heads, deps, transitions):
state_costs[name] = M.get_cost(state, gold, i) state_costs[name] = M.get_cost(state, gold, i)
M.transition(state, gold_action) M.transition(state, gold_action)
cost_history.append(state_costs) cost_history.append(state_costs)
gold.update(state)
return state, cost_history return state, cost_history
@@ -59,7 +61,6 @@ def gold(doc, words):
raise NotImplementedError raise NotImplementedError
@pytest.mark.xfail
def test_oracle_four_words(arc_eager, vocab): def test_oracle_four_words(arc_eager, vocab):
words = ["a", "b", "c", "d"] words = ["a", "b", "c", "d"]
heads = [1, 1, 3, 3] heads = [1, 1, 3, 3]
@@ -144,12 +145,11 @@ def test_get_oracle_actions():
parser.moves.add_action(1, "") parser.moves.add_action(1, "")
parser.moves.add_action(1, "") parser.moves.add_action(1, "")
parser.moves.add_action(4, "ROOT") parser.moves.add_action(4, "ROOT")
heads, deps = projectivize(heads, deps)
for i, (head, dep) in enumerate(zip(heads, deps)): for i, (head, dep) in enumerate(zip(heads, deps)):
if head > i: if head > i:
parser.moves.add_action(2, dep) parser.moves.add_action(2, dep)
elif head < i: elif head < i:
parser.moves.add_action(3, dep) parser.moves.add_action(3, dep)
heads, deps = projectivize(heads, deps)
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}) example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
parser.moves.preprocess_gold(example)
parser.moves.get_oracle_sequence(example) parser.moves.get_oracle_sequence(example)