mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Work on parser oracle
This commit is contained in:
		
							parent
							
								
									914924a68b
								
							
						
					
					
						commit
						c58deb3546
					
				|  | @ -82,7 +82,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp | ||||||
|         gold_i = cand_to_gold[cand_i] |         gold_i = cand_to_gold[cand_i] | ||||||
|         if gold_i is not None: # Alignment found |         if gold_i is not None: # Alignment found | ||||||
|             ref_tok = example.y.c[gold_i] |             ref_tok = example.y.c[gold_i] | ||||||
|             gold_head = gold_to_cand[ref_tok.head + gold_i] |             gold_head = gold_to_cand[gold_i + ref_tok.head] | ||||||
|             if gold_head is not None: |             if gold_head is not None: | ||||||
|                 gs.heads[cand_i] = gold_head |                 gs.heads[cand_i] = gold_head | ||||||
|                 gs.labels[cand_i] = ref_tok.dep |                 gs.labels[cand_i] = ref_tok.dep | ||||||
|  | @ -106,17 +106,17 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp | ||||||
|     stack_words = set() |     stack_words = set() | ||||||
|     for i in range(stcls.stack_depth()): |     for i in range(stcls.stack_depth()): | ||||||
|         s_i = stcls.S(i) |         s_i = stcls.S(i) | ||||||
|         head = s_i + gs.heads[s_i] |         head = gs.heads[s_i] | ||||||
|         gs.n_kids_in_stack[head] += 1 |         gs.n_kids_in_stack[head] += 1 | ||||||
|         stack_words.add(s_i) |         stack_words.add(s_i) | ||||||
|     buffer_words = set() |     buffer_words = set() | ||||||
|     for i in range(stcls.buffer_length()): |     for i in range(stcls.buffer_length()): | ||||||
|         b_i = stcls.B(i) |         b_i = stcls.B(i) | ||||||
|         head = b_i + gs.heads[b_i] |         head = gs.heads[b_i] | ||||||
|         gs.n_kids_in_buffer[head] += 1 |         gs.n_kids_in_buffer[head] += 1 | ||||||
|         buffer_words.add(b_i) |         buffer_words.add(b_i) | ||||||
|     for i in range(gs.length): |     for i in range(gs.length): | ||||||
|         head = i + gs.heads[i] |         head = gs.heads[i] | ||||||
|         if head in stack_words: |         if head in stack_words: | ||||||
|             gs.state_bits[i] = set_state_flag( |             gs.state_bits[i] = set_state_flag( | ||||||
|                 gs.state_bits[i], |                 gs.state_bits[i], | ||||||
|  | @ -142,6 +142,58 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp | ||||||
|     return gs |     return gs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | cdef void update_gold_state(GoldParseStateC* gs, StateClass stcls) except *: | ||||||
|  |     for i in range(gs.length): | ||||||
|  |         gs.state_bits[i] = set_state_flag( | ||||||
|  |             gs.state_bits[i], | ||||||
|  |             HEAD_IN_BUFFER, | ||||||
|  |             0 | ||||||
|  |         ) | ||||||
|  |         gs.state_bits[i] = set_state_flag( | ||||||
|  |             gs.state_bits[i], | ||||||
|  |             HEAD_IN_STACK, | ||||||
|  |             0 | ||||||
|  |         ) | ||||||
|  |         gs.n_kids_in_stack[i] = 0 | ||||||
|  |         gs.n_kids_in_buffer[i] = 0 | ||||||
|  |     stack_words = set() | ||||||
|  |     for i in range(stcls.stack_depth()): | ||||||
|  |         s_i = stcls.S(i) | ||||||
|  |         head = gs.heads[s_i] | ||||||
|  |         gs.n_kids_in_stack[head] += 1 | ||||||
|  |         stack_words.add(s_i) | ||||||
|  |     buffer_words = set() | ||||||
|  |     for i in range(stcls.buffer_length()): | ||||||
|  |         b_i = stcls.B(i) | ||||||
|  |         head = gs.heads[b_i] | ||||||
|  |         gs.n_kids_in_buffer[head] += 1 | ||||||
|  |         buffer_words.add(b_i) | ||||||
|  |     for i in range(gs.length): | ||||||
|  |         head = gs.heads[i] | ||||||
|  |         if head in stack_words: | ||||||
|  |             gs.state_bits[i] = set_state_flag( | ||||||
|  |                 gs.state_bits[i], | ||||||
|  |                 HEAD_IN_STACK, | ||||||
|  |                 1 | ||||||
|  |             ) | ||||||
|  |             gs.state_bits[i] = set_state_flag( | ||||||
|  |                 gs.state_bits[i], | ||||||
|  |                 HEAD_IN_BUFFER, | ||||||
|  |                 0 | ||||||
|  |             ) | ||||||
|  |         elif head in buffer_words: | ||||||
|  |             gs.state_bits[i] = set_state_flag( | ||||||
|  |                 gs.state_bits[i], | ||||||
|  |                 HEAD_IN_STACK, | ||||||
|  |                 0 | ||||||
|  |             ) | ||||||
|  |             gs.state_bits[i] = set_state_flag( | ||||||
|  |                 gs.state_bits[i], | ||||||
|  |                 HEAD_IN_BUFFER, | ||||||
|  |                 1 | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| cdef class ArcEagerGold: | cdef class ArcEagerGold: | ||||||
|     cdef GoldParseStateC c |     cdef GoldParseStateC c | ||||||
|     cdef Pool mem |     cdef Pool mem | ||||||
|  | @ -150,6 +202,9 @@ cdef class ArcEagerGold: | ||||||
|         self.mem = Pool() |         self.mem = Pool() | ||||||
|         self.c = create_gold_state(self.mem, stcls, example) |         self.c = create_gold_state(self.mem, stcls, example) | ||||||
| 
 | 
 | ||||||
|  |     def update(self, StateClass stcls): | ||||||
|  |         update_gold_state(&self.c, stcls) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef int check_state_gold(char state_bits, char flag) nogil: | cdef int check_state_gold(char state_bits, char flag) nogil: | ||||||
|  | @ -319,22 +374,27 @@ cdef class LeftArc: | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: |     cdef inline weight_t cost(StateClass s, const void* _gold, attr_t label) nogil: | ||||||
|         gold = <const GoldParseStateC*>_gold |         gold = <const GoldParseStateC*>_gold | ||||||
|         return RightArc.move_cost(s, gold) + RightArc.label_cost(s, gold, label) |         return LeftArc.move_cost(s, gold) + LeftArc.label_cost(s, gold, label) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef inline weight_t move_cost(StateClass s, const void* _gold) nogil: |     cdef inline weight_t move_cost(StateClass s, const GoldParseStateC* gold) nogil: | ||||||
|         gold = <const GoldParseStateC*>_gold |         cdef weight_t cost = 0 | ||||||
|         if arc_is_gold(gold, s.S(0), s.B(0)): |         s0 = s.S(0) | ||||||
|             return 0 |         b0 = s.B(0) | ||||||
|         elif s.c.shifted[s.B(0)]: |         if arc_is_gold(gold, b0, s0): | ||||||
|             return push_cost(s, gold, s.B(0)) |             # Have a negative cost if we 'recover' from the wrong dependency | ||||||
|  |             return 0 if not s.has_head(s0) else -1 | ||||||
|         else: |         else: | ||||||
|             return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) |             # Account for deps we might lose between S0 and stack | ||||||
|  |             if not s.has_head(s0): | ||||||
|  |                 cost += gold.n_kids_in_stack[s0] | ||||||
|  |                 if is_head_in_buffer(gold, s0): | ||||||
|  |                     cost += 1 | ||||||
|  |             return cost + pop_cost(s, gold, s.S(0)) + arc_cost(s, gold, s.B(0), s.S(0)) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef weight_t label_cost(StateClass s, const void* _gold, attr_t label) nogil: |     cdef inline weight_t label_cost(StateClass s, const GoldParseStateC* gold, attr_t label) nogil: | ||||||
|         gold = <const GoldParseStateC*>_gold |         return arc_is_gold(gold, s.B(0), s.S(0)) and not label_is_gold(gold, s.B(0), s.S(0), label) | ||||||
|         return arc_is_gold(gold, s.S(0), s.B(0)) and not label_is_gold(gold, s.S(0), s.B(0), label) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class RightArc: | cdef class RightArc: | ||||||
|  | @ -622,42 +682,17 @@ cdef class ArcEager(TransitionSystem): | ||||||
| 
 | 
 | ||||||
|     cdef int set_costs(self, int* is_valid, weight_t* costs, |     cdef int set_costs(self, int* is_valid, weight_t* costs, | ||||||
|                        StateClass stcls, gold) except -1: |                        StateClass stcls, gold) except -1: | ||||||
|         gold_state = (<ArcEagerGold>gold).c |         if not isinstance(gold, ArcEagerGold): | ||||||
|         cdef int i, move |             raise TypeError("Expected ArcEagerGold") | ||||||
|         cdef attr_t label |         cdef ArcEagerGold gold_ = gold | ||||||
|         cdef label_cost_func_t[N_MOVES] label_cost_funcs |         gold_.update(stcls) | ||||||
|         cdef move_cost_func_t[N_MOVES] move_cost_funcs |         gold_state = gold_.c | ||||||
|         cdef weight_t[N_MOVES] move_costs |  | ||||||
|         for i in range(N_MOVES): |  | ||||||
|             move_costs[i] = 9000 |  | ||||||
|         move_cost_funcs[SHIFT] = Shift.move_cost |  | ||||||
|         move_cost_funcs[REDUCE] = Reduce.move_cost |  | ||||||
|         move_cost_funcs[LEFT] = LeftArc.move_cost |  | ||||||
|         move_cost_funcs[RIGHT] = RightArc.move_cost |  | ||||||
|         move_cost_funcs[BREAK] = Break.move_cost |  | ||||||
| 
 |  | ||||||
|         label_cost_funcs[SHIFT] = Shift.label_cost |  | ||||||
|         label_cost_funcs[REDUCE] = Reduce.label_cost |  | ||||||
|         label_cost_funcs[LEFT] = LeftArc.label_cost |  | ||||||
|         label_cost_funcs[RIGHT] = RightArc.label_cost |  | ||||||
|         label_cost_funcs[BREAK] = Break.label_cost |  | ||||||
| 
 |  | ||||||
|         cdef attr_t* labels = gold_state.labels |  | ||||||
|         cdef int32_t* heads = gold_state.heads |  | ||||||
| 
 |  | ||||||
|         n_gold = 0 |         n_gold = 0 | ||||||
|         for i in range(self.n_moves): |         for i in range(self.n_moves): | ||||||
|             if self.c[i].is_valid(stcls.c, self.c[i].label): |             if self.c[i].is_valid(stcls.c, self.c[i].label): | ||||||
|                 is_valid[i] = True |                 is_valid[i] = True | ||||||
|                 move = self.c[i].move |                 costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label) | ||||||
|                 label = self.c[i].label |                 n_gold += 1 | ||||||
|                 if move_costs[move] == 9000: |  | ||||||
|                     move_costs[move] = move_cost_funcs[move](stcls, &gold_state) |  | ||||||
|                 move_cost = move_costs[move] |  | ||||||
|                 label_cost = label_cost_funcs[move](stcls, &gold_state, label) |  | ||||||
|                 costs[i] = move_cost + label_cost |  | ||||||
|                 n_gold += costs[i] <= 0 |  | ||||||
|                 print(move, label, costs[i]) |  | ||||||
|             else: |             else: | ||||||
|                 is_valid[i] = False |                 is_valid[i] = False | ||||||
|                 costs[i] = 9000 |                 costs[i] = 9000 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user