mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Add base TransitionSystem class. Still need to rethink how non-monotonic labelling will work for best_valid
This commit is contained in:
		
							parent
							
								
									01bc4d6815
								
							
						
					
					
						commit
						b063001596
					
				
							
								
								
									
										42
									
								
								spacy/syntax/transition_system.pxd
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								spacy/syntax/transition_system.pxd
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,42 @@ | ||||||
|  | from cymem.cymem cimport Pool | ||||||
|  | from thinc.typedefs cimport weight_t | ||||||
|  | 
 | ||||||
|  | from ..structs cimport TokenC | ||||||
|  | from ._state cimport State | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
# A single parser action, stored by value in TransitionSystem's move table.
cdef struct Transition:
    # Identity of the transition. TransitionSystem.__init__ passes these to
    # init_transition as (table index, action, label id).
    int clas
    int move
    int label

    # Scratch fields for callers; nothing in this module reads or writes
    # them -- presumably filled in by the parser/learner (TODO confirm).
    weight_t score
    int cost

    # Behaviour hooks. Each one receives the transition itself as `self`,
    # and each reserves -1 as its exception return value.

    # best_gold treats a return value of 0 as "compatible with gold".
    int (*get_cost)(const Transition* self,
                    const State* state,
                    const TokenC* gold) except -1

    # best_valid uses the result as a truth value.
    int (*is_valid)(const Transition* self,
                    const State* state) except -1

    # The only hook handed a mutable State.
    int (*do)(const Transition* self,
              State* state) except -1


# Named aliases for the three hook signatures above.
ctypedef int (*get_cost_func_t)(const Transition* self,
                                const State* state,
                                const TokenC* gold) except -1

ctypedef int (*is_valid_func_t)(const Transition* self,
                                const State* state) except -1

ctypedef int (*do_func_t)(const Transition* self,
                          State* state) except -1
|  | 
 | ||||||
|  | 
 | ||||||
cdef class TransitionSystem:
    # C-level declaration of the extension type implemented in
    # transition_system.pyx. Every attribute the implementation touches has
    # to be declared here, because cdef classes have no instance __dict__.

    # Maps label string -> integer id; readable (not writable) from Python.
    cdef readonly dict label_ids
    # Memory pool the transition table is allocated from.
    cdef Pool mem
    # Table of transitions, indexed by class id; n_moves entries long.
    cdef const Transition* c
    # Length of `c`. The implementation assigns this in __init__ and uses it
    # as the loop bound in best_valid/best_gold, so it must be declared.
    cdef readonly int n_moves

    # Build the Transition for one (class id, move, label id) triple.
    # The base class implementation raises NotImplementedError.
    cdef Transition init_transition(self, int clas, int move, int label) except *

    # Return types are plain `Transition` (by value) so that they match the
    # definitions in the .pyx; a const qualifier on a by-value struct return
    # carries no meaning for the caller.
    cdef Transition best_valid(self, const weight_t* scores, const State* s) except *

    cdef Transition best_gold(self, const weight_t* scores, const State* s,
                              const TokenC* gold) except *
							
								
								
									
										54
									
								
								spacy/syntax/transition_system.pyx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								spacy/syntax/transition_system.pyx
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,54 @@ | ||||||
|  | from cymem.cymem cimport Pool | ||||||
|  | from ._state cimport State | ||||||
|  | from ..structs cimport TokenC | ||||||
|  | from thinc.typedefs cimport weight_t | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
# Floor for the arg-max searches below: a move is only ever selected when its
# score is strictly greater than this value.
cdef weight_t MIN_SCORE = -90000


cdef class TransitionSystem:
    """Base class holding a fixed table of labelled transitions.

    The table has one Transition per (action, label) pair. Subclasses must
    override init_transition to populate each entry; the base implementation
    raises NotImplementedError.
    """
    def __init__(self, dict labels_by_action):
        """Build the transition table.

        labels_by_action: dict mapping each action to an iterable of labels.
            Actions are visited in sorted order, and the labels of each
            action in sorted order, so class ids are deterministic.
        """
        cdef Transition* moves
        cdef int i = 0
        self.mem = Pool()
        # Count the labels themselves. Iterating .items() here would yield
        # (key, value) tuples, whose len() is always 2.
        self.n_moves = sum(len(labels) for labels in labels_by_action.values())
        moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
        self.label_ids = {}
        for action, label_strs in sorted(labels_by_action.items()):
            for label_str in sorted(label_strs):
                label_str = unicode(label_str)
                # Ids are handed out in first-seen order; a label shared by
                # several actions keeps a single id.
                label_id = self.label_ids.setdefault(label_str, len(self.label_ids))
                # `i` doubles as the transition's class id, i.e. its index
                # into the scores arrays passed to best_valid/best_gold.
                moves[i] = self.init_transition(i, action, label_id)
                i += 1
        self.c = moves

    cdef Transition init_transition(self, int clas, int move, int label) except *:
        """Create the Transition for one table slot. Subclasses must override."""
        raise NotImplementedError

    cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
        """Return the highest-scoring transition whose is_valid hook accepts `s`.

        `scores` is indexed by class id and must hold at least n_moves entries.
        Raises ValueError if no transition is both valid and scored above
        MIN_SCORE, instead of returning an uninitialised struct.
        """
        cdef int best = -1
        cdef weight_t score = MIN_SCORE
        cdef int i
        for i in range(self.n_moves):
            # Score test first, so is_valid only runs for candidates that
            # could actually become the new best.
            if scores[i] > score and self.c[i].is_valid(&self.c[i], s):
                best = i
                score = scores[i]
        if best < 0:
            raise ValueError("No valid transition with score above MIN_SCORE")
        # TODO: non-monotonic labelling is still undecided. The original
        # sketch relabelled a winning Shift move with the label of the
        # best-scoring Right-Arc move:
        #if best.move == SHIFT:
        #    score = MIN_SCORE
        #    for i in range(self.n_moves):
        #        if self.c[i].move == RIGHT and scores[i] > score:
        #            best.label = self.c[i].label
        #            score = scores[i]
        return self.c[best]

    cdef Transition best_gold(self, const weight_t* scores, const State* s,
                              const TokenC* gold) except *:
        """Return the highest-scoring transition whose get_cost hook returns 0.

        `scores` is indexed by class id and must hold at least n_moves entries.
        Raises ValueError if no transition has zero cost and a score above
        MIN_SCORE, instead of returning an uninitialised struct.
        """
        cdef int best = -1
        cdef weight_t score = MIN_SCORE
        cdef int i
        for i in range(self.n_moves):
            if scores[i] > score and self.c[i].get_cost(&self.c[i], s, gold) == 0:
                best = i
                score = scores[i]
        if best < 0:
            raise ValueError("No zero-cost transition with score above MIN_SCORE")
        return self.c[best]
		Loading…
	
		Reference in New Issue
	
	Block a user