This commit is contained in:
Richard Liaw 2020-06-16 19:12:48 -07:00
parent d8705e1291
commit d1de4b1ea9

View File

@ -224,7 +224,18 @@ def train_cli(
optimizer = AllreduceOptimizer(config_path, worker.communicator)
train_args["remote_optimizer"] = optimizer
return setup_and_train(True, train_args, worker.rank, worker.world_size)
ray.get([w.execute.remote(train_fn) for w in workers])
elif strategy == "debug":
remote_train = ray.remote(setup_and_train)
if use_gpu >= 0:
msg.info("Enabling GPU with Ray")
remote_train = remote_train.options(num_gpus=0.9)
ray.get([remote_train.remote(
use_gpu,
train_args,
rank=rank,
total_workers=num_workers) for rank in range(num_workers)])
else:
raise NotImplementedError