This commit is contained in:
Richard Liaw 2020-06-15 21:59:27 -07:00
parent 54951aa976
commit ef2af90f54
2 changed files with 18 additions and 9 deletions

View File

@ -31,8 +31,10 @@ class OptimizerWorker:
class RayOptimizer: class RayOptimizer:
local_optimizer = None local_optimizer = None
def __init__(self, config_path): def __init__(self, config_path, use_gpu):
RemoteOptimizer = ray.remote(OptimizerWorker) RemoteOptimizer = ray.remote(OptimizerWorker)
if use_gpu >= 0:
RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
self.optimizer = RemoteOptimizer.remote(config_path) self.optimizer = RemoteOptimizer.remote(config_path)
self.sync() self.sync()

View File

@ -4,6 +4,7 @@ import math
import srsly import srsly
from pydantic import BaseModel, FilePath from pydantic import BaseModel, FilePath
import plac import plac
import os
import tqdm import tqdm
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
@ -198,7 +199,11 @@ def train_cli(
import ray import ray
ray.init() ray.init()
remote_train = ray.remote(setup_and_train) remote_train = ray.remote(setup_and_train)
train_args["remote_optimizer"] = RayOptimizer(config_path) if use_gpu >= 0:
msg.info("Enabling GPU with Ray")
remote_train = remote_train.options(num_gpus=0.9)
train_args["remote_optimizer"] = RayOptimizer(config_path, use_gpu=use_gpu)
ray.get([remote_train.remote( ray.get([remote_train.remote(
use_gpu, use_gpu,
train_args, train_args,
@ -210,17 +215,19 @@ def train_cli(
world_rank = None world_rank = None
world_size = None world_size = None
def setup_and_train(use_gpu, train_args, rank, total_workers): def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
if use_gpu >= 0: if rank is not None:
msg.info("Using GPU: {use_gpu}")
util.use_gpu(use_gpu)
else:
msg.info("Using CPU")
if rank:
global world_rank global world_rank
world_rank = rank world_rank = rank
global world_size global world_size
world_size = total_workers world_size = total_workers
if use_gpu >= 0:
use_gpu = 0
if use_gpu >= 0:
msg.info(f"Using GPU: {use_gpu}")
util.use_gpu(use_gpu)
else:
msg.info("Using CPU")
train(**train_args) train(**train_args)
def train( def train(