This commit is contained in:
Richard Liaw 2020-06-16 20:33:21 -07:00
parent d1de4b1ea9
commit a5a3ed722c
3 changed files with 52 additions and 45 deletions

View File

@@ -0,0 +1,48 @@
"""Parameter Server distributed training with Ray."""
import ray
from wasabi import msg
from .. import util
class OptimizerWorker:
    """Ray actor that owns the canonical optimizer state.

    Holds one authoritative copy of each parameter array, keyed by ``key``,
    and applies incoming gradients to that copy so that all workers share a
    single source of truth for the weights.
    """

    def __init__(self, config_path):
        # Build the optimizer from the training config on the actor side.
        self.optimizer = _create_optimizer(config_path)
        # Authoritative per-key parameter store.
        self.weights_dict = {}

    def call(self, key, weights, gradient, *, lr_scale=1.0):
        # First sighting of this key: seed the store with the caller's weights.
        if key not in self.weights_dict:
            self.weights_dict[key] = weights.copy()
        stored = self.weights_dict[key]
        # Gradient is copied so the optimizer may mutate it freely.
        updated, residual = self.optimizer(
            key, stored, gradient.copy(), lr_scale=lr_scale
        )
        self.weights_dict[key] = updated
        return updated, residual

    def fetch(self):
        # Expose the optimizer object so the proxy can mirror it locally.
        return self.optimizer

    def step_schedules(self):
        # Advance learning-rate (and similar) schedules by one step.
        self.optimizer.step_schedules()
class RayOptimizer:
    """Client-side proxy for a remote ``OptimizerWorker`` actor.

    Update calls are routed to the actor, while ``local_optimizer`` keeps a
    fetched snapshot of the remote optimizer so plain attribute reads can be
    answered without a network round trip.
    """

    # Snapshot of the remote optimizer, refreshed by sync().
    local_optimizer = None

    def __init__(self, config_path, use_gpu):
        actor_cls = ray.remote(OptimizerWorker)
        if use_gpu >= 0:
            # Reserve a fraction of a GPU so several actors can share one device.
            actor_cls = actor_cls.options(num_gpus=0.1)
        self.optimizer = actor_cls.remote(config_path)
        self.sync()

    def sync(self):
        # Pull a fresh copy of the remote optimizer for local attribute reads.
        self.local_optimizer = ray.get(self.optimizer.fetch.remote())

    def __call__(self, *args, **kwargs):
        future = self.optimizer.call.remote(*args, **kwargs)
        weights, grads = ray.get(future)
        # Copy so callers can mutate the arrays without touching shared state.
        return weights.copy(), grads.copy()

    def __getattr__(self, name):
        # Anything not defined here is served from the local snapshot.
        return getattr(self.local_optimizer, name)

    def step_schedules(self):
        self.optimizer.step_schedules.remote()
        # Re-sync so locally-read hyperparameters reflect the schedule step.
        self.sync()

View File

@@ -1,3 +1,5 @@
"""Allreduce distributed training with Ray."""
import ray
from wasabi import msg
from .. import util
@@ -16,49 +18,6 @@ def _create_optimizer(config_path):
training = config["training"]
return training["optimizer"]
class OptimizerWorker:
    """Actor-side holder of the shared optimizer and parameter store.

    Each parameter array is cached under its ``key`` on first use; later
    gradient updates are applied to the cached copy, keeping one canonical
    set of weights across all training workers.
    """

    def __init__(self, config_path):
        # The optimizer is constructed from the training config file.
        self.optimizer = _create_optimizer(config_path)
        self.weights_dict = {}

    def call(self, key, weights, gradient, *, lr_scale=1.0):
        # Seed the cache with a copy of the caller's weights on first use;
        # setdefault leaves an existing entry untouched.
        self.weights_dict.setdefault(key, weights.copy())
        new_weights, new_grads = self.optimizer(
            key,
            self.weights_dict[key],
            gradient.copy(),  # copied: the optimizer may modify it in place
            lr_scale=lr_scale,
        )
        self.weights_dict[key] = new_weights
        return new_weights, new_grads

    def fetch(self):
        # Hand the optimizer object back for local mirroring.
        return self.optimizer

    def step_schedules(self):
        # Move all optimizer schedules forward one step.
        self.optimizer.step_schedules()
class RayOptimizer:
    """Thin wrapper that forwards optimizer work to a remote actor.

    A local mirror of the remote optimizer (``local_optimizer``) is kept so
    that ordinary attribute access does not require a remote call.
    """

    # Locally cached copy of the remote optimizer; see sync().
    local_optimizer = None

    def __init__(self, config_path, use_gpu):
        remote_cls = ray.remote(OptimizerWorker)
        if use_gpu >= 0:
            # Fractional GPU reservation lets multiple actors share a device.
            remote_cls = remote_cls.options(num_gpus=0.1)
        self.optimizer = remote_cls.remote(config_path)
        self.sync()

    def sync(self):
        # Refresh the local mirror from the actor.
        self.local_optimizer = ray.get(self.optimizer.fetch.remote())

    def __call__(self, *args, **kwargs):
        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
        # Return copies so the caller owns the arrays outright.
        return weights.copy(), grads.copy()

    def __getattr__(self, name):
        # Delegate unknown attributes to the cached optimizer snapshot.
        return getattr(self.local_optimizer, name)

    def step_schedules(self):
        self.optimizer.step_schedules.remote()
        # Keep the mirror consistent with the stepped schedules.
        self.sync()
class RayWorker:
def __init__(self, rank, world_size):
global nccl

View File

@@ -143,7 +143,7 @@ def train_cli(
verbose=False,
use_gpu=-1,
num_workers=1,
strategy="ps",
strategy="allreduce",
tag_map_path=None,
omit_extra_lookups=False,
):
@@ -197,10 +197,10 @@ def train_cli(
)
if num_workers and num_workers > 1:
from spacy.cli.ray_utils import RayOptimizer
import ray
ray.init()
if strategy == "ps":
from spacy.cli.ray_param_server import RayOptimizer
remote_train = ray.remote(setup_and_train)
if use_gpu >= 0:
msg.info("Enabling GPU with Ray")