mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-29 06:57:49 +03:00 
			
		
		
		
	Update test_arc_eager_oracle
This commit is contained in:
		
							parent
							
								
									7544c21f5b
								
							
						
					
					
						commit
						9db66ddd48
					
				|  | @ -13,8 +13,9 @@ from spacy.syntax.arc_eager import ArcEager | |||
| def get_sequence_costs(M, words, heads, deps, transitions): | ||||
|     doc = Doc(Vocab(), words=words) | ||||
|     example = Example.from_dict(doc, {"heads": heads, "deps": deps}) | ||||
|     state = StateClass(doc) | ||||
|     M.preprocess_gold(example) | ||||
|     states, golds, _ = M.init_gold_batch([example]) | ||||
|     state = states[0] | ||||
|     gold = golds[0] | ||||
|     cost_history = [] | ||||
|     for gold_action in transitions: | ||||
|         state_costs = {} | ||||
|  | @ -23,6 +24,7 @@ def get_sequence_costs(M, words, heads, deps, transitions): | |||
|             state_costs[name] = M.get_cost(state, gold, i) | ||||
|         M.transition(state, gold_action) | ||||
|         cost_history.append(state_costs) | ||||
|         gold.update(state) | ||||
|     return state, cost_history | ||||
| 
 | ||||
| 
 | ||||
|  | @ -59,7 +61,6 @@ def gold(doc, words): | |||
|         raise NotImplementedError | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.xfail | ||||
| def test_oracle_four_words(arc_eager, vocab): | ||||
|     words = ["a", "b", "c", "d"] | ||||
|     heads = [1, 1, 3, 3] | ||||
|  | @ -144,12 +145,11 @@ def test_get_oracle_actions(): | |||
|     parser.moves.add_action(1, "") | ||||
|     parser.moves.add_action(1, "") | ||||
|     parser.moves.add_action(4, "ROOT") | ||||
|     heads, deps = projectivize(heads, deps) | ||||
|     for i, (head, dep) in enumerate(zip(heads, deps)): | ||||
|         if head > i: | ||||
|             parser.moves.add_action(2, dep) | ||||
|         elif head < i: | ||||
|             parser.moves.add_action(3, dep) | ||||
|     heads, deps = projectivize(heads, deps) | ||||
|     example = Example.from_dict(doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}) | ||||
|     parser.moves.preprocess_gold(example) | ||||
|     parser.moves.get_oracle_sequence(example) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user