mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
Update test_arc_eager_oracle
This commit is contained in:
parent
7544c21f5b
commit
9db66ddd48
|
@ -13,8 +13,9 @@ from spacy.syntax.arc_eager import ArcEager
|
||||||
def get_sequence_costs(M, words, heads, deps, transitions):
|
def get_sequence_costs(M, words, heads, deps, transitions):
|
||||||
doc = Doc(Vocab(), words=words)
|
doc = Doc(Vocab(), words=words)
|
||||||
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
|
example = Example.from_dict(doc, {"heads": heads, "deps": deps})
|
||||||
state = StateClass(doc)
|
states, golds, _ = M.init_gold_batch([example])
|
||||||
M.preprocess_gold(example)
|
state = states[0]
|
||||||
|
gold = golds[0]
|
||||||
cost_history = []
|
cost_history = []
|
||||||
for gold_action in transitions:
|
for gold_action in transitions:
|
||||||
state_costs = {}
|
state_costs = {}
|
||||||
|
@ -23,6 +24,7 @@ def get_sequence_costs(M, words, heads, deps, transitions):
|
||||||
state_costs[name] = M.get_cost(state, gold, i)
|
state_costs[name] = M.get_cost(state, gold, i)
|
||||||
M.transition(state, gold_action)
|
M.transition(state, gold_action)
|
||||||
cost_history.append(state_costs)
|
cost_history.append(state_costs)
|
||||||
|
gold.update(state)
|
||||||
return state, cost_history
|
return state, cost_history
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +61,6 @@ def gold(doc, words):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_oracle_four_words(arc_eager, vocab):
|
def test_oracle_four_words(arc_eager, vocab):
|
||||||
words = ["a", "b", "c", "d"]
|
words = ["a", "b", "c", "d"]
|
||||||
heads = [1, 1, 3, 3]
|
heads = [1, 1, 3, 3]
|
||||||
|
@ -144,12 +145,11 @@ def test_get_oracle_actions():
|
||||||
parser.moves.add_action(1, "")
|
parser.moves.add_action(1, "")
|
||||||
parser.moves.add_action(1, "")
|
parser.moves.add_action(1, "")
|
||||||
parser.moves.add_action(4, "ROOT")
|
parser.moves.add_action(4, "ROOT")
|
||||||
|
heads, deps = projectivize(heads, deps)
|
||||||
for i, (head, dep) in enumerate(zip(heads, deps)):
|
for i, (head, dep) in enumerate(zip(heads, deps)):
|
||||||
if head > i:
|
if head > i:
|
||||||
parser.moves.add_action(2, dep)
|
parser.moves.add_action(2, dep)
|
||||||
elif head < i:
|
elif head < i:
|
||||||
parser.moves.add_action(3, dep)
|
parser.moves.add_action(3, dep)
|
||||||
heads, deps = projectivize(heads, deps)
|
|
||||||
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
|
example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps})
|
||||||
parser.moves.preprocess_gold(example)
|
|
||||||
parser.moves.get_oracle_sequence(example)
|
parser.moves.get_oracle_sequence(example)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user