mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	sync
This commit is contained in:
		
							parent
							
								
									fdc9242bc1
								
							
						
					
					
						commit
						26c975ec66
					
				|  | @ -1,23 +1,44 @@ | |||
| """Parameter Server distributed training with Ray.""" | ||||
| 
 | ||||
| import threading | ||||
| import ray | ||||
| from wasabi import msg | ||||
| from .. import util | ||||
| from spacy.cli.ray_utils import create_optimizer | ||||
| 
 | ||||
| class OptimizerWorker: | ||||
|     def __init__(self, config_path, world_size, sync=True): | ||||
|         self.optimizer = _create_optimizer(config_path) | ||||
|     def __init__(self, config_path, world_size): | ||||
|         self.optimizer = create_optimizer(config_path) | ||||
|         self.new_weights = None | ||||
|         self.barrier = threading.Barrier(world_size) | ||||
|         self.lock = threading.Lock() | ||||
|         self.waiting = 0 | ||||
|         self.weights_dict = {} | ||||
|         self.grad_dict = {} | ||||
|         self.world_size = world_size | ||||
|         self.sync = sync | ||||
| 
 | ||||
|     def call(self, rank, key, weights, gradient, *, lr_scale=1.0): | ||||
|         if key not in self.weights_dict: | ||||
|     def call(self, key, weights, gradient, *, lr_scale=1.0): | ||||
|         self.lock.acquire() | ||||
| 
 | ||||
|         if self.waiting < self.world_size - 1: | ||||
|             if self.waiting == 0: | ||||
|                 self.gradient[key] = gradient.copy() | ||||
|                 self.weights_dict[key] = weights.copy() | ||||
|             else: | ||||
|                 self.gradient[key] += gradient | ||||
|             self.waiting = self.barrier.n_waiting + 1 | ||||
|             self.lock.release() | ||||
|             self.barrier.wait() | ||||
|         else: | ||||
|             self.gradient[key] += gradient | ||||
|             self.lock.release() | ||||
|             self.gradient[key] /= self.world_size | ||||
|             new_weights, new_grads = self.optimizer( | ||||
|             key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale) | ||||
|                 key, self.weights_dict[key], self.gradient[key], lr_scale=lr_scale) | ||||
|             self.weights_dict[key] = new_weights | ||||
|         return new_weights, new_grads | ||||
|             self.gradient[key] = new_grads | ||||
|             self.waiting = 0 | ||||
|             self.barrier.wait() | ||||
|         return self.weights_dict[key], self.gradient[key] | ||||
| 
 | ||||
|     def fetch(self): | ||||
|         return self.optimizer | ||||
|  | @ -40,7 +61,7 @@ class RayOptimizer: | |||
|         self.local_optimizer = ray.get(self.optimizer.fetch.remote()) | ||||
| 
 | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         weights, grads = ray.get(self.optimizer.call.remote(self.rank, *args, **kwargs)) | ||||
|         weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs)) | ||||
|         return weights.copy(), grads.copy() | ||||
| 
 | ||||
|     def __getattr__(self, name): | ||||
|  |  | |||
|  | @ -10,7 +10,7 @@ nccl = None | |||
| from typing import Dict, Optional, Union, Tuple, List, cast | ||||
| from thinc.types import FloatsXd | ||||
| 
 | ||||
| def _create_optimizer(config_path): | ||||
| def create_optimizer(config_path): | ||||
|     msg.info(f"Loading config from: {config_path}") | ||||
|     config = util.load_config(config_path, create_objects=False) | ||||
|     util.fix_random_seed(config["training"]["seed"])  # Fix this. | ||||
|  | @ -41,7 +41,7 @@ class AllreduceOptimizer: | |||
|         import cupy as cp | ||||
|         global nccl | ||||
|         from cupy.cuda import nccl | ||||
|         self.optimizer = _create_optimizer(config_path) | ||||
|         self.optimizer = create_optimizer(config_path) | ||||
|         self.communicator = communicator | ||||
|         self.weights_synced = set() | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user