mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Fix arc_eager oracle
This commit is contained in:
		
							parent
							
								
									a68d0e63f0
								
							
						
					
					
						commit
						420a986d15
					
				|  | @ -200,7 +200,6 @@ cdef class ArcEagerGold: | |||
|         sent_starts = example.get_aligned("SENT_START") | ||||
|         assert len(heads) == len(labels) == len(sent_starts) | ||||
|         self.c = create_gold_state(self.mem, stcls, heads, labels, sent_starts) | ||||
|         self.update(stcls) | ||||
| 
 | ||||
|     def update(self, StateClass stcls): | ||||
|         update_gold_state(&self.c, stcls) | ||||
|  | @ -577,17 +576,12 @@ cdef class ArcEager(TransitionSystem): | |||
|     def is_gold_parse(self, StateClass state, gold): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def has_gold(self, gold, start=0, end=None): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def preprocess_gold(self, example): | ||||
|         raise NotImplementedError | ||||
| 
 | ||||
|     def init_gold_batch(self, examples): | ||||
|         examples = [eg for eg in examples if self.has_gold(eg)] | ||||
|         states = self.init_batch([eg.predicted for eg in examples]) | ||||
|         keeps = [i for i, s in enumerate(states) if not s.is_final()] | ||||
|         states = [states[i] for i in keeps] | ||||
|         golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps] | ||||
|         states = [states[i] for i in keeps] | ||||
|         for gold in golds: | ||||
|             self._replace_unseen_labels(gold) | ||||
|         n_steps = sum([len(s.queue) * 4 for s in states]) | ||||
|  | @ -690,6 +684,9 @@ cdef class ArcEager(TransitionSystem): | |||
|         doc.is_parsed = True | ||||
|         set_children_from_heads(doc.c, doc.length) | ||||
| 
 | ||||
|     def has_gold(self, Example eg): | ||||
|         return eg.y.is_parsed | ||||
| 
 | ||||
|     cdef int set_valid(self, int* output, const StateC* st) nogil: | ||||
|         cdef bint[N_MOVES] is_valid | ||||
|         is_valid[SHIFT] = Shift.is_valid(st, 0) | ||||
|  | @ -736,21 +733,29 @@ cdef class ArcEager(TransitionSystem): | |||
|             raise ValueError | ||||
| 
 | ||||
|     def get_oracle_sequence(self, Example example): | ||||
|         cdef StateClass state | ||||
|         cdef ArcEagerGold gold | ||||
|         states, golds, n_steps = self.init_gold_batch([example]) | ||||
|         if not golds: | ||||
|             return [] | ||||
| 
 | ||||
|         cdef Pool mem = Pool() | ||||
|         # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc | ||||
|         assert self.n_moves > 0 | ||||
|         costs = <float*>mem.alloc(self.n_moves, sizeof(float)) | ||||
|         is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) | ||||
| 
 | ||||
|         cdef StateClass state | ||||
|         cdef ArcEagerGold gold | ||||
|         states, golds, n_steps = self.init_gold_batch([example]) | ||||
|         state = states[0] | ||||
|         gold = golds[0] | ||||
|         history = [] | ||||
|         debug_log = [] | ||||
|         failed = False | ||||
|         while not state.is_final(): | ||||
|             self.set_costs(is_valid, costs, state, gold) | ||||
|             try: | ||||
|                 self.set_costs(is_valid, costs, state, gold) | ||||
|             except ValueError: | ||||
|                 failed = True | ||||
|                 break | ||||
|             for i in range(self.n_moves): | ||||
|                 if is_valid[i] and costs[i] <= 0: | ||||
|                     action = self.c[i] | ||||
|  | @ -766,36 +771,39 @@ cdef class ArcEager(TransitionSystem): | |||
|                     action.do(state.c, action.label) | ||||
|                     break | ||||
|             else: | ||||
|                 print("Actions") | ||||
|                 for i in range(self.n_moves): | ||||
|                     print(self.get_class_name(i)) | ||||
|                 print("Gold") | ||||
|                 for token in example.y: | ||||
|                     print(token.i, token.text, token.dep_, token.head.text) | ||||
|                 aligned_heads, aligned_labels = example.get_aligned_parse() | ||||
|                 print("Aligned heads") | ||||
|                 for i, head in enumerate(aligned_heads): | ||||
|                     print(example.x[i], example.x[head] if head is not None else "__") | ||||
|                 failed = False | ||||
|                 break | ||||
|         if failed: | ||||
|             print("Actions") | ||||
|             for i in range(self.n_moves): | ||||
|                 print(self.get_class_name(i)) | ||||
|             print("Gold") | ||||
|             for token in example.y: | ||||
|                 print(token.i, token.text, token.dep_, token.head.text) | ||||
|             aligned_heads, aligned_labels = example.get_aligned_parse() | ||||
|             print("Aligned heads") | ||||
|             for i, head in enumerate(aligned_heads): | ||||
|                 print(example.x[i], example.x[head] if head is not None else "__") | ||||
| 
 | ||||
|                 print("Predicted tokens") | ||||
|                 print([(w.i, w.text) for w in example.x]) | ||||
|                 s0 = state.S(0) | ||||
|                 b0 = state.B(0) | ||||
|                 debug_log.append(" ".join(( | ||||
|                     "?", | ||||
|                     "S0=", (example.x[s0].text if s0 >= 0 else "-"), | ||||
|                     "B0=", (example.x[b0].text if b0 >= 0 else "-"), | ||||
|                     "S0 head?", str(state.has_head(state.S(0))), | ||||
|                 ))) | ||||
|                 s0 = state.S(0) | ||||
|                 b0 = state.B(0) | ||||
|                 print("\n".join(debug_log)) | ||||
|                 print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) | ||||
|                 print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) | ||||
|                 print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) | ||||
|                 print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) | ||||
|                 print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) | ||||
|                 print("Stack", [example.x[i] for i in state.stack]) | ||||
|                 print("Buffer", [example.x[i] for i in state.queue]) | ||||
|                 raise ValueError(Errors.E024) | ||||
|             print("Predicted tokens") | ||||
|             print([(w.i, w.text) for w in example.x]) | ||||
|             s0 = state.S(0) | ||||
|             b0 = state.B(0) | ||||
|             debug_log.append(" ".join(( | ||||
|                 "?", | ||||
|                 "S0=", (example.x[s0].text if s0 >= 0 else "-"), | ||||
|                 "B0=", (example.x[b0].text if b0 >= 0 else "-"), | ||||
|                 "S0 head?", str(state.has_head(state.S(0))), | ||||
|             ))) | ||||
|             s0 = state.S(0) | ||||
|             b0 = state.B(0) | ||||
|             print("\n".join(debug_log)) | ||||
|             print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) | ||||
|             print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) | ||||
|             print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) | ||||
|             print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) | ||||
|             print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) | ||||
|             print("Stack", [example.x[i] for i in state.stack]) | ||||
|             print("Buffer", [example.x[i] for i in state.queue]) | ||||
|             raise ValueError(Errors.E024) | ||||
|         return history | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user