mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-28 06:31:12 +03:00 
			
		
		
		
	Update ArcEager oracle
Fix Break oracle
This commit is contained in:
		
							parent
							
								
									3354758351
								
							
						
					
					
						commit
						e9860daf4b
					
				|  | @ -53,6 +53,8 @@ cdef enum: | |||
|     HEAD_IN_STACK = 0 | ||||
|     HEAD_IN_BUFFER | ||||
|     HEAD_UNKNOWN | ||||
|     IS_SENT_START | ||||
|     SENT_START_UNKNOWN | ||||
| 
 | ||||
| 
 | ||||
| cdef struct GoldParseStateC: | ||||
|  | @ -76,6 +78,43 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp | |||
|     gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0])) | ||||
| 
 | ||||
|     heads, labels = example.get_aligned_parse(projectivize=True) | ||||
|     sent_starts = example.get_aligned("SENT_START") | ||||
|     for i, is_sent_start in enumerate(sent_starts): | ||||
|         if is_sent_start == True: | ||||
|             gs.state_bits[i] = set_state_flag( | ||||
|                 gs.state_bits[i], | ||||
|                 IS_SENT_START, | ||||
|                 1 | ||||
|             ) | ||||
|             gs.state_bits[i] = set_state_flag( | ||||
|                 gs.state_bits[i], | ||||
|                 SENT_START_UNKNOWN, | ||||
|                 0 | ||||
|             ) | ||||
|   | ||||
|         elif is_sent_start is None: | ||||
|             gs.state_bits[i] = set_state_flag( | ||||
|                 gs.state_bits[i], | ||||
|                 SENT_START_UNKNOWN, | ||||
|                 1 | ||||
|             ) | ||||
|             gs.state_bits[i] = set_state_flag( | ||||
|                 gs.state_bits[i], | ||||
|                 IS_SENT_START, | ||||
|                 0 | ||||
|             ) | ||||
|         else: | ||||
|             gs.state_bits[i] = set_state_flag( | ||||
|                 gs.state_bits[i], | ||||
|                 SENT_START_UNKNOWN, | ||||
|                 0 | ||||
|             ) | ||||
|             gs.state_bits[i] = set_state_flag( | ||||
|                 gs.state_bits[i], | ||||
|                 IS_SENT_START, | ||||
|                 0 | ||||
|             ) | ||||
|   | ||||
|     cdef TokenC ref_tok | ||||
|     for i, (head, label) in enumerate(zip(heads, labels)): | ||||
|         if head is not None: | ||||
|  | @ -220,6 +259,13 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil: | |||
| cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil: | ||||
|     return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN) | ||||
| 
 | ||||
| cdef int is_sent_start(const GoldParseStateC* gold, int i) nogil: | ||||
|     return check_state_gold(gold.state_bits[i], IS_SENT_START) | ||||
| 
 | ||||
| cdef int is_sent_start_unknown(const GoldParseStateC* gold, int i) nogil: | ||||
|     return check_state_gold(gold.state_bits[i], SENT_START_UNKNOWN) | ||||
| 
 | ||||
| 
 | ||||
| # Helper functions for the arc-eager oracle | ||||
| 
 | ||||
| cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: | ||||
|  | @ -251,7 +297,7 @@ cdef weight_t arc_cost(StateClass stcls, const void* _gold, int head, int child) | |||
|     elif stcls.H(child) == gold.heads[child]: | ||||
|         return 1 | ||||
|     # Head in buffer | ||||
|     elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0: | ||||
|     elif is_head_in_buffer(gold, child): | ||||
|         return 1 | ||||
|     else: | ||||
|         return 0 | ||||
|  | @ -452,15 +498,15 @@ cdef class Break: | |||
|     @staticmethod | ||||
|     cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: | ||||
|         gold = <const GoldParseStateC*>_gold | ||||
|         cdef weight_t cost = 0 | ||||
|         cdef int i, j, S_i, B_i | ||||
|         cost = 0 | ||||
|         for i in range(s.stack_depth()): | ||||
|             S_i = s.S(i) | ||||
|             cost += gold.n_kids_in_buffer[S_i] | ||||
|             if is_head_in_buffer(gold, S_i): | ||||
|                 cost += 1 | ||||
|         # Check for sentence boundary --- if it's here, we can't have any deps | ||||
|         # between stack and buffer, so rest of action is irrelevant. | ||||
|         # It's weird not to check the gold sentence boundaries but if we do, | ||||
|         # we can't account for "sunk costs", i.e. situations where we're already | ||||
|         # wrong. | ||||
|         s0_root = _get_root(s.S(0), gold) | ||||
|         b0_root = _get_root(s.B(0), gold) | ||||
|         if s0_root != b0_root or s0_root == -1 or b0_root == -1: | ||||
|  | @ -538,6 +584,7 @@ cdef class ArcEager(TransitionSystem): | |||
|                 for label, freq in list(label_freqs.items()): | ||||
|                     if freq < min_freq: | ||||
|                         label_freqs.pop(label) | ||||
|                         print("Removing", action, label, freq) | ||||
|         # Ensure these actions are present | ||||
|         actions[BREAK].setdefault('ROOT', 0) | ||||
|         if kwargs.get("learn_tokens") is True: | ||||
|  | @ -588,7 +635,7 @@ cdef class ArcEager(TransitionSystem): | |||
|         for i in range(self.n_moves): | ||||
|             if self.c[i].move == move and self.c[i].label == label: | ||||
|                 return self.c[i] | ||||
|         return Transition(clas=0, move=MISSING, label=0) | ||||
|         raise KeyError(f"Unknown transition: {name}") | ||||
| 
 | ||||
|     def move_name(self, int move, attr_t label): | ||||
|         label_str = self.strings[label] | ||||
|  | @ -691,12 +738,79 @@ cdef class ArcEager(TransitionSystem): | |||
|             if self.c[i].is_valid(stcls.c, self.c[i].label): | ||||
|                 is_valid[i] = True | ||||
|                 costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) | ||||
|                 n_gold += 1 | ||||
|                 n_gold += costs[i] <= 0 | ||||
|             else: | ||||
|                 is_valid[i] = False | ||||
|                 costs[i] = 9000 | ||||
|         if n_gold < 1: | ||||
|             raise ValueError | ||||
|             #failure_state = stcls.print_state([t.text for t in example]) | ||||
|             #raise ValueError( | ||||
|             #    Errors.E021.format(n_actions=self.n_moves, state=failure_state)) | ||||
|         #if n_gold < 1: | ||||
|         #    raise ValueError | ||||
|         #    #failure_state = stcls.print_state([t.text for t in example]) | ||||
|         #    #raise ValueError( | ||||
|         #    #    Errors.E021.format(n_actions=self.n_moves, state=failure_state)) | ||||
| 
 | ||||
|     def get_oracle_sequence(self, Example example): | ||||
|         cdef Pool mem = Pool() | ||||
|         # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc | ||||
|         assert self.n_moves > 0 | ||||
|         costs = <float*>mem.alloc(self.n_moves, sizeof(float)) | ||||
|         is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) | ||||
| 
 | ||||
|         cdef StateClass state | ||||
|         cdef ArcEagerGold gold | ||||
|         states, golds, n_steps = self.init_gold_batch([example]) | ||||
|         state = states[0] | ||||
|         gold = golds[0] | ||||
|         history = [] | ||||
|         debug_log = [] | ||||
|         while not state.is_final(): | ||||
|             self.set_costs(is_valid, costs, state, gold) | ||||
|             for i in range(self.n_moves): | ||||
|                 if is_valid[i] and costs[i] <= 0: | ||||
|                     action = self.c[i] | ||||
|                     history.append(i) | ||||
|                     s0 = state.S(0) | ||||
|                     b0 = state.B(0) | ||||
|                     debug_log.append(" ".join(( | ||||
|                         self.get_class_name(i), | ||||
|                         "S0=", (example.x[s0].text if s0 >= 0 else "__"), | ||||
|                         "B0=", (example.x[b0].text if b0 >= 0 else "__"), | ||||
|                         "S0 head?", str(state.has_head(state.S(0))), | ||||
|                     ))) | ||||
|                     action.do(state.c, action.label) | ||||
|                     break | ||||
|             else: | ||||
|                 print("Actions") | ||||
|                 for i in range(self.n_moves): | ||||
|                     print(self.get_class_name(i)) | ||||
|                 print("Gold") | ||||
|                 for token in example.y: | ||||
|                     print(token.i, token.text, token.dep_, token.head.text) | ||||
|                 aligned_heads, aligned_labels = example.get_aligned_parse() | ||||
|                 print("Aligned heads") | ||||
|                 for i, head in enumerate(aligned_heads): | ||||
|                     print(example.x[i], example.x[head] if head is not None else "__") | ||||
| 
 | ||||
|                 print("Predicted tokens") | ||||
|                 print([(w.i, w.text) for w in example.x]) | ||||
|                 s0 = state.S(0) | ||||
|                 b0 = state.B(0) | ||||
|                 debug_log.append(" ".join(( | ||||
|                     "?", | ||||
|                     "S0=", (example.x[s0].text if s0 >= 0 else "-"), | ||||
|                     "B0=", (example.x[b0].text if b0 >= 0 else "-"), | ||||
|                     "S0 head?", str(state.has_head(state.S(0))), | ||||
|                 ))) | ||||
|                 s0 = state.S(0) | ||||
|                 b0 = state.B(0) | ||||
|                 print("\n".join(debug_log)) | ||||
|                 print("Arc is gold B0, S0?", arc_is_gold(&gold.c, b0, s0)) | ||||
|                 print("Arc is gold S0, B0?", arc_is_gold(&gold.c, s0, b0)) | ||||
|                 print("is_head_unknown(s0)", is_head_unknown(&gold.c, s0)) | ||||
|                 print("is_head_unknown(b0)", is_head_unknown(&gold.c, b0)) | ||||
|                 print("b0", b0, "gold.heads[s0]", gold.c.heads[s0]) | ||||
|                 print("Stack", [example.x[i] for i in state.stack]) | ||||
|                 print("Buffer", [example.x[i] for i in state.queue]) | ||||
|                 raise ValueError(Errors.E024) | ||||
|         return history | ||||
| 
 | ||||
|   | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user