mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-07 15:10:34 +03:00
debug
This commit is contained in:
parent
d8705e1291
commit
d1de4b1ea9
|
@ -224,7 +224,18 @@ def train_cli(
|
||||||
optimizer = AllreduceOptimizer(config_path, worker.communicator)
|
optimizer = AllreduceOptimizer(config_path, worker.communicator)
|
||||||
train_args["remote_optimizer"] = optimizer
|
train_args["remote_optimizer"] = optimizer
|
||||||
return setup_and_train(True, train_args, worker.rank, worker.world_size)
|
return setup_and_train(True, train_args, worker.rank, worker.world_size)
|
||||||
|
|
||||||
ray.get([w.execute.remote(train_fn) for w in workers])
|
ray.get([w.execute.remote(train_fn) for w in workers])
|
||||||
|
elif strategy == "debug":
|
||||||
|
remote_train = ray.remote(setup_and_train)
|
||||||
|
if use_gpu >= 0:
|
||||||
|
msg.info("Enabling GPU with Ray")
|
||||||
|
remote_train = remote_train.options(num_gpus=0.9)
|
||||||
|
ray.get([remote_train.remote(
|
||||||
|
use_gpu,
|
||||||
|
train_args,
|
||||||
|
rank=rank,
|
||||||
|
total_workers=num_workers) for rank in range(num_workers)])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user