mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Improve correctness of minibatching
This commit is contained in:
		
							parent
							
								
									dc07d72d80
								
							
						
					
					
						commit
						a1d4c97fb7
					
				|  | @ -427,7 +427,7 @@ cdef class Parser: | |||
| 
 | ||||
|         cuda_stream = get_cuda_stream() | ||||
| 
 | ||||
|         states, golds, max_length = self._init_gold_batch(docs, golds) | ||||
|         states, golds, max_steps = self._init_gold_batch(docs, golds) | ||||
|         state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, | ||||
|                                                       0.0) | ||||
|         todo = [(s, g) for (s, g) in zip(states, golds) | ||||
|  | @ -438,6 +438,7 @@ cdef class Parser: | |||
|         backprops = [] | ||||
|         d_tokvecs = state2vec.ops.allocate(tokvecs.shape) | ||||
|         cdef float loss = 0. | ||||
|         n_steps = 0 | ||||
|         while todo: | ||||
|             states, golds = zip(*todo) | ||||
| 
 | ||||
|  | @ -467,7 +468,8 @@ cdef class Parser: | |||
|             todo = [st for st in todo if not st[0].is_final()] | ||||
|             if losses is not None: | ||||
|                 losses[self.name] += (d_scores**2).sum() | ||||
|             if len(backprops) >= (max_length * 2): | ||||
|             n_steps += 1 | ||||
|             if n_steps >= max_steps: | ||||
|                 break | ||||
|         self._make_updates(d_tokvecs, | ||||
|             backprops, sgd, cuda_stream) | ||||
|  | @ -482,7 +484,8 @@ cdef class Parser: | |||
|             StateClass state | ||||
|             Transition action | ||||
|         whole_states = self.moves.init_batch(whole_docs) | ||||
|         max_length = max(5, min(20, min([len(doc) for doc in whole_docs]))) | ||||
|         max_length = max(5, min(50, min([len(doc) for doc in whole_docs]))) | ||||
|         max_moves = 0 | ||||
|         states = [] | ||||
|         golds = [] | ||||
|         for doc, state, gold in zip(whole_docs, whole_states, whole_golds): | ||||
|  | @ -493,16 +496,20 @@ cdef class Parser: | |||
|             start = 0 | ||||
|             while start < len(doc): | ||||
|                 state = state.copy() | ||||
|                 n_moves = 0 | ||||
|                 while state.B(0) < start and not state.is_final(): | ||||
|                     action = self.moves.c[oracle_actions.pop(0)] | ||||
|                     action.do(state.c, action.label) | ||||
|                     n_moves += 1 | ||||
|                 has_gold = self.moves.has_gold(gold, start=start, | ||||
|                                                end=start+max_length) | ||||
|                 if not state.is_final() and has_gold: | ||||
|                     states.append(state) | ||||
|                     golds.append(gold) | ||||
|                     max_moves = max(max_moves, n_moves) | ||||
|                 start += min(max_length, len(doc)-start) | ||||
|         return states, golds, max_length | ||||
|             max_moves = max(max_moves, len(oracle_actions)) | ||||
|         return states, golds, max_moves | ||||
| 
 | ||||
|     def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None): | ||||
|         # Tells CUDA to block, so our async copies complete. | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user