mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Update oracle tests for Split
This commit is contained in:
parent
e31ef9c7f6
commit
f7beefe9c1
|
@ -10,9 +10,9 @@ from ...syntax.stateclass import StateClass
|
|||
from ...syntax.arc_eager import ArcEager
|
||||
|
||||
|
||||
def get_sequence_costs(M, words, heads, deps, transitions):
|
||||
def get_sequence_costs(M, words, gold_words, heads, deps, transitions):
|
||||
doc = Doc(Vocab(), words=words)
|
||||
gold = GoldParse(doc, heads=heads, deps=deps)
|
||||
gold = GoldParse(doc, words=gold_words, heads=heads, deps=deps)
|
||||
state = StateClass(doc)
|
||||
M.preprocess_gold(gold)
|
||||
cost_history = []
|
||||
|
@ -21,6 +21,7 @@ def get_sequence_costs(M, words, heads, deps, transitions):
|
|||
for i in range(M.n_moves):
|
||||
name = M.class_name(i)
|
||||
state_costs[name] = M.get_cost(state, gold, i)
|
||||
|
||||
M.transition(state, gold_action)
|
||||
cost_history.append(state_costs)
|
||||
return state, cost_history
|
||||
|
@ -124,7 +125,7 @@ def test_oracle_four_words(arc_eager, vocab):
|
|||
heads = [1, 1, 3, 3]
|
||||
deps = ['left', 'ROOT', 'left', 'ROOT']
|
||||
actions = ['S', 'L-left', 'S', 'B-ROOT', 'S', 'L-left', 'S']
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions)
|
||||
assert state.is_final()
|
||||
for i, state_costs in enumerate(cost_history):
|
||||
# Check gold moves is 0 cost
|
||||
|
@ -138,7 +139,7 @@ def test_non_monotonic_sequence_four_words(arc_eager, vocab):
|
|||
heads = [1, 1, 3, 3]
|
||||
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
|
||||
actions = ['S', 'R-right', 'R-right', 'L-left', 'L-left', 'L-left', 'S']
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions)
|
||||
assert state.is_final()
|
||||
c0 = cost_history.pop(0)
|
||||
assert c0['S'] == 0.0
|
||||
|
@ -162,7 +163,7 @@ def test_oracle_at_sentence_break(arc_eager, vocab):
|
|||
heads = [1, 1, 3, 3]
|
||||
deps = ['left', 'B-ROOT', 'left', 'B-ROOT']
|
||||
actions = ['S', 'R-right', 'D', 'B-ROOT', 'S']
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, words, heads, deps, actions)
|
||||
assert not state.is_final(), state.print_state(words)
|
||||
c0 = cost_history.pop(0)
|
||||
c1 = cost_history.pop(0)
|
||||
|
@ -175,17 +176,30 @@ def test_oracle_at_sentence_break(arc_eager, vocab):
|
|||
|
||||
|
||||
def test_split_oracle(arc_eager, vocab):
|
||||
if arc_eager.max_split < 2:
|
||||
return
|
||||
gold_words = ['a', 'b', 'c']
|
||||
doc = Doc(vocab, words=['ab', 'c'])
|
||||
words = ['ab', 'c']
|
||||
doc = Doc(vocab, words=words)
|
||||
heads = [2, 2, 2]
|
||||
deps = ['dep', 'dep', 'ROOT']
|
||||
actions = ['P-1', 'S', 'L-dep', 'S', 'B-ROOT']
|
||||
actions = ['P-1', 'S', 'S', 'L-dep', 'L-dep', 'S', 'B-ROOT']
|
||||
gold = GoldParse(doc, words=gold_words, heads=heads, deps=deps)
|
||||
assert gold.heads == [[1, 1], 1]
|
||||
assert gold.labels == [['dep', 'dep'], 'ROOT']
|
||||
state = StateClass(doc)
|
||||
M = arc_eager
|
||||
M.add_action(5, '1')
|
||||
M.preprocess_gold(gold)
|
||||
assert gold.fused[0] == 1
|
||||
state, cost_history = get_sequence_costs(arc_eager, words, gold_words, heads, deps, actions)
|
||||
assert state.is_final()
|
||||
for i, state_costs in enumerate(cost_history):
|
||||
# Check gold moves is 0 cost
|
||||
assert state_costs[actions[i]] == 0.0, actions[i]
|
||||
for other_action, cost in state_costs.items():
|
||||
if other_action != actions[i]:
|
||||
assert cost >= 1, (i, other_action, actions[i])
|
||||
|
||||
|
||||
annot_tuples = [
|
||||
|
|
Loading…
Reference in New Issue
Block a user