Mirror of https://github.com/explosion/spaCy.git

Commit: fdc9242bc1 ("small-changes")
Parent: a5a3ed722c
@@ -5,11 +5,13 @@ from wasabi import msg
 from .. import util

 class OptimizerWorker:
-    def __init__(self, config_path):
+    def __init__(self, config_path, world_size, sync=True):
         self.optimizer = _create_optimizer(config_path)
         self.weights_dict = {}
+        self.world_size = world_size
+        self.sync = sync

-    def call(self, key, weights, gradient, *, lr_scale=1.0):
+    def call(self, rank, key, weights, gradient, *, lr_scale=1.0):
         if key not in self.weights_dict:
             self.weights_dict[key] = weights.copy()
         new_weights, new_grads = self.optimizer(
@@ -26,18 +28,19 @@ class OptimizerWorker:
 class RayOptimizer:
     local_optimizer = None

-    def __init__(self, config_path, use_gpu):
+    def __init__(self, config_path, use_gpu, rank):
         RemoteOptimizer = ray.remote(OptimizerWorker)
         if use_gpu >= 0:
             RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
         self.optimizer = RemoteOptimizer.remote(config_path)
+        self.rank = rank
         self.sync()

     def sync(self):
         self.local_optimizer = ray.get(self.optimizer.fetch.remote())

     def __call__(self, *args, **kwargs):
-        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
+        weights, grads = ray.get(self.optimizer.call.remote(self.rank, *args, **kwargs))
         return weights.copy(), grads.copy()

     def __getattr__(self, name):
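The two hunks above thread a per-worker rank (plus a world_size and a sync flag whose use falls outside the shown context) through the shared OptimizerWorker actor, so the parameter server can tell which training process each gradient came from. A minimal sketch of that pattern, using stand-in names (ParamServer, a toy SGD step) rather than the actual spaCy code:

import ray

@ray.remote
class ParamServer:
    # Stand-in for OptimizerWorker: remember one gradient per rank and apply
    # an averaged update once every rank in the group has reported.
    def __init__(self, world_size):
        self.world_size = world_size
        self.pending = {}

    def call(self, rank, key, weights, gradient):
        self.pending.setdefault(key, {})[rank] = gradient
        if len(self.pending[key]) < self.world_size:
            return weights, gradient  # still waiting for the other ranks
        avg = sum(self.pending.pop(key).values()) / self.world_size
        return weights - 0.01 * avg, avg  # toy SGD step

ray.init()
ps = ParamServer.remote(world_size=2)
# Each training process holds its own rank and routes updates through the actor:
weights, grad = ray.get(ps.call.remote(0, "tok2vec", 1.0, 0.5))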
@@ -69,7 +69,7 @@ class AllreduceOptimizer:
             weights = self.allreduce(weights) / self.communicator.size()


-        gradient = self.allreduce(gradient)
+        gradient = self.allreduce(gradient) / self.communicator.size()
         flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
         return flat_weights, gradient
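In the hunk above the allreduced gradient is now divided by the communicator size, matching what was already done for the weights: an allreduce sums the per-worker tensors, so the division turns that sum back into a mean. A toy illustration of the arithmetic (plain NumPy, not the actual allreduce API):

import numpy as np

world_size = 4
per_worker = [np.array([0.2, -0.1]) for _ in range(world_size)]
summed = sum(per_worker)        # what an allreduce returns: [0.8, -0.4]
averaged = summed / world_size  # what the patched line uses: [0.2, -0.1]
print(summed, averaged)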
@@ -198,7 +198,7 @@ def train_cli(

     if num_workers and num_workers > 1:
         import ray
-        ray.init()
+        ray.init(address="auto")
         if strategy == "ps":
             from spacy.cli.ray_param_server import RayOptimizer
             remote_train = ray.remote(setup_and_train)
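ray.init() with no arguments starts a private single-node Ray instance inside the process, whereas ray.init(address="auto") attaches to a Ray cluster that is already running, which is what a multi-worker run needs. A small sketch of how a run might connect (the ray start commands are shown as comments and assume a default cluster setup):

import ray

# Cluster started beforehand, e.g.:
#   ray start --head                     # on the head node
#   ray start --address=<head-ip>:6379   # on each additional node
ray.init(address="auto")
print(ray.cluster_resources())  # resources contributed by every connected node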
@@ -401,8 +401,12 @@ def train(
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)

+    tqdm_args = dict(total=training["eval_frequency"], leave=False)
+    global world_rank
+    if world_rank is not None:
+        tqdm_args["disable"] = bool(world_rank != 0)
     try:
-        progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
+        progress = tqdm.tqdm(**tqdm_args)
         for batch, info, is_best_checkpoint in training_step_iterator:
             progress.update(1)
             if is_best_checkpoint is not None:
@@ -411,7 +415,7 @@ def train(
                 if is_best_checkpoint and output_path is not None:
                     update_meta(training, nlp, info)
                     nlp.to_disk(output_path / "model-best")
-                progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
+                progress = tqdm.tqdm(**tqdm_args)
             # Clean up the objects to faciliate garbage collection.
             for eg in batch:
                 eg.doc = None
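These two hunks build the tqdm keyword arguments once and set disable on every worker except rank 0, so only a single process draws the progress bar when several workers share a terminal or log stream. A short sketch of the same idea with a hypothetical helper:

import tqdm

def make_progress(total, world_rank=None):
    # Rank 0 (or a non-distributed run) shows a live bar; other ranks stay silent.
    disable = world_rank is not None and world_rank != 0
    return tqdm.tqdm(total=total, leave=False, disable=disable)

bar = make_progress(total=200, world_rank=0)    # visible
quiet = make_progress(total=200, world_rank=3)  # suppressed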
@@ -437,6 +441,10 @@ def train(

 def create_train_batches(nlp, corpus, cfg):
     epochs_todo = cfg.get("max_epochs", 0)
+    if world_rank is not None:
+        for i in range(world_rank):
+            # Increment random seed
+            random.random()
     while True:
         train_examples = list(
             corpus.train_dataset(
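The added loop draws and discards world_rank values from the shared random generator before the epoch loop, so each worker ends up in a different RNG state and the later random.shuffle(train_examples) produces a different example order per rank despite identical seeds. A toy demonstration of the effect (using a local random.Random rather than the module-level generator):

import random

def shuffled_for_rank(items, seed, world_rank):
    rng = random.Random(seed)
    for _ in range(world_rank):
        rng.random()  # burn one draw per rank, as the patch does
    items = list(items)
    rng.shuffle(items)
    return items

print(shuffled_for_rank(range(6), seed=0, world_rank=0))
print(shuffled_for_rank(range(6), seed=0, world_rank=1))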
@@ -452,16 +460,18 @@ def create_train_batches(nlp, corpus, cfg):
             raise ValueError(Errors.E988)
         random.shuffle(train_examples)

-        if world_size is not None:
-            # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
-            num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
-            total_size = num_samples * world_size  # expected to overflow
-            train_examples += train_examples[:(total_size - len(train_examples))]
-            assert len(train_examples) == total_size
+        # # TODO: with large batches, this can be bad.
+        # if world_size is not None:
+        #     # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+        #     num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
+        #     total_size = num_samples * world_size  # expected to overflow
+        #     train_examples += train_examples[:(total_size - len(train_examples))]
+        #     assert len(train_examples) == total_size

-            # subsample
-            train_examples = train_examples[world_rank:total_size:world_size]
-            assert len(train_examples) == num_samples
+        #     # subsample
+        #     train_examples = train_examples[world_rank:total_size:world_size]
+        #     assert len(train_examples) == num_samples
+        #     print(f"Reset epoch: Only using {num_samples} out of {total_size} samples")

         batches = util.minibatch_by_words(
             train_examples,
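The block this hunk comments out is the padding-and-striding scheme borrowed from PyTorch's DistributedSampler: pad the shuffled list to a multiple of world_size, then give each rank every world_size-th example starting at its own offset; the new TODO suggests it interacted badly with large batches. For reference, a standalone sketch of that scheme (hypothetical helper name):

import math

def shard_for_rank(examples, world_rank, world_size):
    # Pad so the list divides evenly, then take a strided slice per rank.
    num_samples = int(math.ceil(len(examples) / world_size))
    total_size = num_samples * world_size
    padded = examples + examples[: total_size - len(examples)]
    return padded[world_rank:total_size:world_size]

data = list(range(10))
print(shard_for_rank(data, world_rank=0, world_size=3))  # [0, 3, 6, 9]
print(shard_for_rank(data, world_rank=2, world_size=3))  # [2, 5, 8, 1]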
@@ -474,7 +484,7 @@ def create_train_batches(nlp, corpus, cfg):
             yield first
         except StopIteration:
             raise ValueError(Errors.E986)
-        for batch in batches:
+        for i, batch in enumerate(batches):
             yield batch
         epochs_todo -= 1
         # We intentionally compare exactly to 0 here, so that max_epochs < 1