mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-29 23:17:59 +03:00 
			
		
		
		
	set
This commit is contained in:
		
							parent
							
								
									8cc49c5a03
								
							
						
					
					
						commit
						d8705e1291
					
				|  | @ -84,6 +84,7 @@ class AllreduceOptimizer: | |||
|         from cupy.cuda import nccl | ||||
|         self.optimizer = _create_optimizer(config_path) | ||||
|         self.communicator = communicator | ||||
|         self.weights_synced = set() | ||||
| 
 | ||||
|     def allreduce(self, tensor): | ||||
|         self.communicator.allReduce( | ||||
|  | @ -104,7 +105,11 @@ class AllreduceOptimizer: | |||
|         *, | ||||
|         lr_scale: float = 1.0, | ||||
|     ): | ||||
|         # weights = self.allreduce(weights) | ||||
|         if key not in self.weights_synced: | ||||
|             self.weights_synced.add(key) | ||||
|             weights = self.allreduce(weights) / self.communicator.size() | ||||
| 
 | ||||
| 
 | ||||
|         gradient = self.allreduce(gradient) | ||||
|         flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale) | ||||
|         return flat_weights, gradient | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user