mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	Fix dropout and learn rate in parser
This commit is contained in:
		
							parent
							
								
									b40bc20b12
								
							
						
					
					
						commit
						1a59db1c86
					
				|  | @ -39,6 +39,7 @@ from preshed.maps cimport map_get | |||
| from thinc.api import layerize, chain, noop, clone | ||||
| from thinc.neural import Model, Affine, ReLu, Maxout | ||||
| from thinc.neural._classes.selu import SELU | ||||
| from thinc.neural._classes.layernorm import LayerNorm | ||||
| from thinc.neural.ops import NumpyOps, CupyOps | ||||
| from thinc.neural.util import get_array_module | ||||
| 
 | ||||
|  | @ -467,7 +468,7 @@ cdef class Parser: | |||
|             docs = [docs] | ||||
|             golds = [golds] | ||||
|         if USE_FINE_TUNE: | ||||
|             my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=0.) | ||||
|             my_tokvecs, bp_my_tokvecs = self.model[0].begin_update(docs_tokvecs, drop=drop) | ||||
|             my_tokvecs = self.model[0].ops.flatten(my_tokvecs) | ||||
|             tokvecs += my_tokvecs | ||||
| 
 | ||||
|  | @ -496,13 +497,13 @@ cdef class Parser: | |||
|             scores, bp_scores = vec2scores.begin_update(vector, drop=drop) | ||||
| 
 | ||||
|             d_scores = self.get_batch_loss(states, golds, scores) | ||||
|             d_vector = bp_scores(d_scores / d_scores.shape[0], sgd=sgd) | ||||
|             d_vector = bp_scores(d_scores, sgd=sgd) | ||||
|             if drop != 0: | ||||
|                 d_vector *= mask | ||||
| 
 | ||||
|             if isinstance(self.model[0].ops, CupyOps) \ | ||||
|             and not isinstance(token_ids, state2vec.ops.xp.ndarray): | ||||
|                 # Move token_ids and d_vector to CPU, asynchronously | ||||
|                 # Move token_ids and d_vector to GPU, asynchronously | ||||
|                 backprops.append(( | ||||
|                     get_async(cuda_stream, token_ids), | ||||
|                     get_async(cuda_stream, d_vector), | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user