mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	Fix history features
This commit is contained in:
		
							parent
							
								
									b770f4e108
								
							
						
					
					
						commit
						278a4c17c6
					
				|  | @ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG | ||||||
| from . import _beam_utils | from . import _beam_utils | ||||||
| 
 | 
 | ||||||
| USE_HISTORY = True | USE_HISTORY = True | ||||||
|  | HIST_SIZE = 2 | ||||||
|  | HIST_DIMS = 16 | ||||||
| 
 | 
 | ||||||
| def get_templates(*args, **kwargs): | def get_templates(*args, **kwargs): | ||||||
|     return [] |     return [] | ||||||
|  | @ -262,13 +264,11 @@ cdef class Parser: | ||||||
| 
 | 
 | ||||||
|         with Model.use_device('cpu'): |         with Model.use_device('cpu'): | ||||||
|             if depth == 0: |             if depth == 0: | ||||||
|                 hist_size = 8 |  | ||||||
|                 nr_dim = 8 |  | ||||||
|                 if USE_HISTORY: |                 if USE_HISTORY: | ||||||
|                     upper = chain( |                     upper = chain( | ||||||
|                         HistoryFeatures(nr_class=nr_class, hist_size=hist_size, |                         HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, | ||||||
|                                         nr_dim=nr_dim), |                                         nr_dim=HIST_DIMS), | ||||||
|                         zero_init(Affine(nr_class, nr_class+hist_size*nr_dim, |                         zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS, | ||||||
|                                           drop_factor=0.0))) |                                           drop_factor=0.0))) | ||||||
|                     upper.is_noop = False |                     upper.is_noop = False | ||||||
|                 else: |                 else: | ||||||
|  | @ -736,15 +736,13 @@ cdef class Parser: | ||||||
|         cdef StateClass state |         cdef StateClass state | ||||||
|         cdef int[500] is_valid # TODO: Unhack |         cdef int[500] is_valid # TODO: Unhack | ||||||
|         cdef float* c_scores = &scores[0, 0] |         cdef float* c_scores = &scores[0, 0] | ||||||
|         hists = [] |  | ||||||
|         for state in states: |         for state in states: | ||||||
|             self.moves.set_valid(is_valid, state.c) |             self.moves.set_valid(is_valid, state.c) | ||||||
|             guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) |             guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) | ||||||
|             action = self.moves.c[guess] |             action = self.moves.c[guess] | ||||||
|             action.do(state.c, action.label) |             action.do(state.c, action.label) | ||||||
|             c_scores += scores.shape[1] |             c_scores += scores.shape[1] | ||||||
|             hists.append(guess) |             state.c.push_hist(guess) | ||||||
|         return hists |  | ||||||
| 
 | 
 | ||||||
|     def get_batch_loss(self, states, golds, float[:, ::1] scores): |     def get_batch_loss(self, states, golds, float[:, ::1] scores): | ||||||
|         cdef StateClass state |         cdef StateClass state | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user