mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	* Adjust arc_eager oracle, so that recovering errors via non-monotonic actions gives negative cost. Need to test this with greedy parser.
This commit is contained in:
		
							parent
							
								
									0bf448461e
								
							
						
					
					
						commit
						1ee6b468a9
					
				|  | @ -1,3 +1,6 @@ | |||
| # cython: profile=True | ||||
| # cython: cdivision=True | ||||
| # cython: infer_types=True | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import ctypes | ||||
|  | @ -24,6 +27,7 @@ from .nonproj import PseudoProjectivity | |||
| DEF NON_MONOTONIC = True | ||||
| DEF USE_BREAK = True | ||||
| 
 | ||||
| 
 | ||||
| cdef weight_t MIN_SCORE = -90000 | ||||
| 
 | ||||
| # Break transition from here | ||||
|  | @ -65,10 +69,12 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no | |||
| cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: | ||||
|     cdef weight_t cost = 0 | ||||
|     cdef int i, B_i | ||||
|     # Count number of words in buffer with dependencies to/from the target. | ||||
|     for i in range(stcls.buffer_length()): | ||||
|         B_i = stcls.B(i) | ||||
|         cost += gold.heads[B_i] == target | ||||
|         cost += gold.heads[target] == B_i | ||||
|         # TODO: Should re-examine this for German --- it assumes projectivity. | ||||
|         if gold.heads[B_i] == B_i or gold.heads[B_i] < target: | ||||
|             break | ||||
|     if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0: | ||||
|  | @ -154,7 +160,18 @@ cdef class Reduce: | |||
| 
 | ||||
|     @staticmethod | ||||
|     cdef inline weight_t move_cost(StateClass st, const GoldParseC* gold) nogil: | ||||
|         return pop_cost(st, gold, st.S(0)) | ||||
|         cost = pop_cost(st, gold, st.S(0)) | ||||
|         if not st.has_head(st.S(0)): | ||||
|             # Decrement cost for the arcs we save | ||||
|             for i in range(1, st.stack_depth()): | ||||
|                 S_i = st.S(i) | ||||
|                 if gold.heads[st.S(0)] == S_i: | ||||
|                     cost -= 1 | ||||
|                 if gold.heads[S_i] == st.S(0): | ||||
|                     cost -= 1 | ||||
|             if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0: | ||||
|                 cost -= 1 | ||||
|         return cost | ||||
| 
 | ||||
|     @staticmethod | ||||
|     cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, int label) nogil: | ||||
|  | @ -180,7 +197,8 @@ cdef class LeftArc: | |||
|     cdef inline weight_t move_cost(StateClass s, const GoldParseC* gold) nogil: | ||||
|         cdef weight_t cost = 0 | ||||
|         if arc_is_gold(gold, s.B(0), s.S(0)): | ||||
|             return 0 | ||||
|             # Have a negative cost if we 'recover' from the wrong dependency | ||||
|             return 0 if not s.has_head(s.S(0)) else -1 | ||||
|         else: | ||||
|             # Account for deps we might lose between S0 and stack | ||||
|             if not s.has_head(s.S(0)): | ||||
|  | @ -407,7 +425,7 @@ cdef class ArcEager(TransitionSystem): | |||
|         cdef move_cost_func_t[N_MOVES] move_cost_funcs | ||||
|         cdef weight_t[N_MOVES] move_costs | ||||
|         for i in range(N_MOVES): | ||||
|             move_costs[i] = -1 | ||||
|             move_costs[i] = 9000 | ||||
|         move_cost_funcs[SHIFT] = Shift.move_cost | ||||
|         move_cost_funcs[REDUCE] = Reduce.move_cost | ||||
|         move_cost_funcs[LEFT] = LeftArc.move_cost | ||||
|  | @ -429,10 +447,10 @@ cdef class ArcEager(TransitionSystem): | |||
|                 is_valid[i] = True | ||||
|                 move = self.c[i].move | ||||
|                 label = self.c[i].label | ||||
|                 if move_costs[move] == -1: | ||||
|                 if move_costs[move] == 9000: | ||||
|                     move_costs[move] = move_cost_funcs[move](stcls, &gold.c) | ||||
|                 costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label) | ||||
|                 n_gold += costs[i] == 0 | ||||
|                 n_gold += costs[i] <= 0 | ||||
|             else: | ||||
|                 is_valid[i] = False | ||||
|                 costs[i] = 9000 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user