mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Adjust arc_eager oracle, so that recovering errors via non-monotonic actions gives negative cost. Need to test this with greedy parser.
This commit is contained in:
		
							parent
							
								
									0bf448461e
								
							
						
					
					
						commit
						1ee6b468a9
					
				|  | @ -1,3 +1,6 @@ | ||||||
|  | # cython: profile=True | ||||||
|  | # cython: cdivision=True | ||||||
|  | # cython: infer_types=True | ||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
| 
 | 
 | ||||||
| import ctypes | import ctypes | ||||||
|  | @ -24,6 +27,7 @@ from .nonproj import PseudoProjectivity | ||||||
| DEF NON_MONOTONIC = True | DEF NON_MONOTONIC = True | ||||||
| DEF USE_BREAK = True | DEF USE_BREAK = True | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| cdef weight_t MIN_SCORE = -90000 | cdef weight_t MIN_SCORE = -90000 | ||||||
| 
 | 
 | ||||||
| # Break transition from here | # Break transition from here | ||||||
|  | @ -65,10 +69,12 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no | ||||||
| cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: | cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: | ||||||
|     cdef weight_t cost = 0 |     cdef weight_t cost = 0 | ||||||
|     cdef int i, B_i |     cdef int i, B_i | ||||||
|  |     # Count number of words in buffer with deendencies to/from the target. | ||||||
|     for i in range(stcls.buffer_length()): |     for i in range(stcls.buffer_length()): | ||||||
|         B_i = stcls.B(i) |         B_i = stcls.B(i) | ||||||
|         cost += gold.heads[B_i] == target |         cost += gold.heads[B_i] == target | ||||||
|         cost += gold.heads[target] == B_i |         cost += gold.heads[target] == B_i | ||||||
|  |         # TODO: Should re-examine this for German --- it assumes projectivity. | ||||||
|         if gold.heads[B_i] == B_i or gold.heads[B_i] < target: |         if gold.heads[B_i] == B_i or gold.heads[B_i] < target: | ||||||
|             break |             break | ||||||
|     if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0: |     if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0: | ||||||
|  | @ -154,7 +160,18 @@ cdef class Reduce: | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: |     cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: | ||||||
|         return pop_cost(st, gold, st.S(0)) |         cost = pop_cost(st, gold, st.S(0)) | ||||||
|  |         if not st.has_head(st.S(0)): | ||||||
|  |             # Decrement cost for the arcs we save | ||||||
|  |             for i in range(1, st.stack_depth()): | ||||||
|  |                 S_i = st.S(i) | ||||||
|  |                 if gold.heads[st.S(0)] == S_i: | ||||||
|  |                     cost -= 1 | ||||||
|  |                 if gold.heads[S_i] == st.S(0): | ||||||
|  |                     cost -= 1 | ||||||
|  |             if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0: | ||||||
|  |                 cost -= 1 | ||||||
|  |         return cost | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil: |     cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil: | ||||||
|  | @ -180,7 +197,8 @@ cdef class LeftArc: | ||||||
|     cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: |     cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: | ||||||
|         cdef weight_t cost = 0 |         cdef weight_t cost = 0 | ||||||
|         if arc_is_gold(gold, s.B(0), s.S(0)): |         if arc_is_gold(gold, s.B(0), s.S(0)): | ||||||
|             return 0 |             # Have a negative cost if we 'recover' from the wrong dependency | ||||||
|  |             return 0 if not s.has_head(s.S(0)) else -1 | ||||||
|         else: |         else: | ||||||
|             # Account for deps we might lose between S0 and stack |             # Account for deps we might lose between S0 and stack | ||||||
|             if not s.has_head(s.S(0)): |             if not s.has_head(s.S(0)): | ||||||
|  | @ -407,7 +425,7 @@ cdef class ArcEager(TransitionSystem): | ||||||
|         cdef move_cost_func_t[N_MOVES] move_cost_funcs |         cdef move_cost_func_t[N_MOVES] move_cost_funcs | ||||||
|         cdef weight_t[N_MOVES] move_costs |         cdef weight_t[N_MOVES] move_costs | ||||||
|         for i in range(N_MOVES): |         for i in range(N_MOVES): | ||||||
|             move_costs[i] = -1 |             move_costs[i] = 9000 | ||||||
|         move_cost_funcs[SHIFT] = Shift.move_cost |         move_cost_funcs[SHIFT] = Shift.move_cost | ||||||
|         move_cost_funcs[REDUCE] = Reduce.move_cost |         move_cost_funcs[REDUCE] = Reduce.move_cost | ||||||
|         move_cost_funcs[LEFT] = LeftArc.move_cost |         move_cost_funcs[LEFT] = LeftArc.move_cost | ||||||
|  | @ -429,10 +447,10 @@ cdef class ArcEager(TransitionSystem): | ||||||
|                 is_valid[i] = True |                 is_valid[i] = True | ||||||
|                 move = self.c[i].move |                 move = self.c[i].move | ||||||
|                 label = self.c[i].label |                 label = self.c[i].label | ||||||
|                 if move_costs[move] == -1: |                 if move_costs[move] == 9000: | ||||||
|                     move_costs[move] = move_cost_funcs[move](stcls, &gold.c) |                     move_costs[move] = move_cost_funcs[move](stcls, &gold.c) | ||||||
|                 costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) |                 costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) | ||||||
|                 n_gold += costs[i] == 0 |                 n_gold += costs[i] <= 0 | ||||||
|             else: |             else: | ||||||
|                 is_valid[i] = False |                 is_valid[i] = False | ||||||
|                 costs[i] = 9000 |                 costs[i] = 9000 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user