diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index ac7fda292..0ef978bfa 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -17,13 +17,13 @@ def get_sequence_costs(M, words, heads, deps, transitions): gold = golds[0] cost_history = [] for gold_action in transitions: + gold.update(state) state_costs = {} for i in range(M.n_moves): name = M.class_name(i) state_costs[name] = M.get_cost(state, gold, i) M.transition(state, gold_action) cost_history.append(state_costs) - gold.update(state) return state, cost_history @@ -55,7 +55,7 @@ def test_oracle_four_words(arc_eager, vocab): assert state_costs[actions[i]] == 0.0, actions[i] for other_action, cost in state_costs.items(): if other_action != actions[i]: - assert cost >= 1 + assert cost >= 1, (i, other_action) annot_tuples = [