mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 15:37:29 +03:00 
			
		
		
		
	debug
This commit is contained in:
		
							parent
							
								
									d8705e1291
								
							
						
					
					
						commit
						d1de4b1ea9
					
				|  | @ -224,7 +224,18 @@ def train_cli( | ||||||
|                 optimizer = AllreduceOptimizer(config_path, worker.communicator) |                 optimizer = AllreduceOptimizer(config_path, worker.communicator) | ||||||
|                 train_args["remote_optimizer"] = optimizer |                 train_args["remote_optimizer"] = optimizer | ||||||
|                 return setup_and_train(True, train_args, worker.rank, worker.world_size) |                 return setup_and_train(True, train_args, worker.rank, worker.world_size) | ||||||
|  | 
 | ||||||
|             ray.get([w.execute.remote(train_fn) for w in workers]) |             ray.get([w.execute.remote(train_fn) for w in workers]) | ||||||
|  |         elif strategy == "debug": | ||||||
|  |             remote_train = ray.remote(setup_and_train) | ||||||
|  |             if use_gpu >= 0: | ||||||
|  |                 msg.info("Enabling GPU with Ray") | ||||||
|  |                 remote_train = remote_train.options(num_gpus=0.9) | ||||||
|  |             ray.get([remote_train.remote( | ||||||
|  |                 use_gpu, | ||||||
|  |                 train_args, | ||||||
|  |                 rank=rank, | ||||||
|  |                 total_workers=num_workers) for rank in range(num_workers)]) | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError |             raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user