This commit is contained in:
Richard Liaw 2020-06-15 21:59:27 -07:00
parent 54951aa976
commit ef2af90f54
2 changed files with 18 additions and 9 deletions

View File

@ -31,8 +31,10 @@ class OptimizerWorker:
class RayOptimizer:
local_optimizer = None
def __init__(self, config_path):
def __init__(self, config_path, use_gpu):
RemoteOptimizer = ray.remote(OptimizerWorker)
if use_gpu >= 0:
RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
self.optimizer = RemoteOptimizer.remote(config_path)
self.sync()

View File

@ -4,6 +4,7 @@ import math
import srsly
from pydantic import BaseModel, FilePath
import plac
import os
import tqdm
from pathlib import Path
from wasabi import msg
@ -198,7 +199,11 @@ def train_cli(
import ray
ray.init()
remote_train = ray.remote(setup_and_train)
train_args["remote_optimizer"] = RayOptimizer(config_path)
if use_gpu >= 0:
msg.info("Enabling GPU with Ray")
remote_train = remote_train.options(num_gpus=0.9)
train_args["remote_optimizer"] = RayOptimizer(config_path, use_gpu=use_gpu)
ray.get([remote_train.remote(
use_gpu,
train_args,
@ -210,17 +215,19 @@ def train_cli(
world_rank = None
world_size = None
def setup_and_train(use_gpu, train_args, rank, total_workers):
if use_gpu >= 0:
msg.info("Using GPU: {use_gpu}")
util.use_gpu(use_gpu)
else:
msg.info("Using CPU")
if rank:
def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
if rank is not None:
global world_rank
world_rank = rank
global world_size
world_size = total_workers
if use_gpu >= 0:
use_gpu = 0
if use_gpu >= 0:
msg.info(f"Using GPU: {use_gpu}")
util.use_gpu(use_gpu)
else:
msg.info("Using CPU")
train(**train_args)
def train(