"""Parameter Server distributed training with Ray."""
import threading
import ray
from wasabi import msg
from .. import util
from spacy.cli.ray_utils import create_optimizer


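# OptimizerWorker is meant to run as a single Ray actor shared by all training
# workers. For each parameter `key`, every worker sends its current weights and
# gradient to `call()`; the actor accumulates gradients under a lock, and a
# threading.Barrier sized to `world_size` makes the last worker to arrive
# average the gradients, apply the optimizer, and release the others with the
# updated weights and gradients.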
class OptimizerWorker:
    def __init__(self, config_path, world_size):
        self.optimizer = create_optimizer(config_path)
        self.new_weights = None
        self.barrier = threading.Barrier(world_size)
        self.lock = threading.Lock()
        self.waiting = 0
        self.weights_dict = {}
        self.grad_dict = {}
        self.world_size = world_size

    def call(self, key, weights, gradient, *, lr_scale=1.0):
        self.lock.acquire()
        if self.waiting < self.world_size - 1:
            # Not the last worker for this step: accumulate the gradient and
            # block on the barrier until the final worker applies the update.
            if self.waiting == 0:
                self.grad_dict[key] = gradient.copy()
                self.weights_dict[key] = weights.copy()
            else:
                self.grad_dict[key] += gradient
            self.waiting = self.barrier.n_waiting + 1
            self.lock.release()
            self.barrier.wait()
        else:
            # Last worker for this step: average the accumulated gradients,
            # run the optimizer, then release the waiting workers.
            self.grad_dict[key] += gradient
            self.lock.release()
            self.grad_dict[key] /= self.world_size
            new_weights, new_grads = self.optimizer(
                key, self.weights_dict[key], self.grad_dict[key], lr_scale=lr_scale)
            self.weights_dict[key] = new_weights
            self.grad_dict[key] = new_grads
            self.waiting = 0
            self.barrier.wait()
        return self.weights_dict[key], self.grad_dict[key]

    def fetch(self):
        return self.optimizer

    def step_schedules(self):
        self.optimizer.step_schedules()


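# RayOptimizer is the worker-side proxy for the OptimizerWorker actor. It keeps
# a local copy of the optimizer returned by `fetch()` (refreshed via `sync()`)
# so that plain attribute reads resolve locally through `__getattr__`, while
# the actual parameter updates in `__call__` go through the remote actor.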
class RayOptimizer:
    local_optimizer = None

    def __init__(self, config_path, use_gpu, rank, world_size):
        RemoteOptimizer = ray.remote(OptimizerWorker)
        if use_gpu >= 0:
            # Reserve a small GPU fraction so the actor can share a GPU with
            # the training workers.
            RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
        # The worker's barrier needs the world size, so it is passed through
        # when the actor is created.
        self.optimizer = RemoteOptimizer.remote(config_path, world_size)
        self.rank = rank
        self.sync()

    def sync(self):
        self.local_optimizer = ray.get(self.optimizer.fetch.remote())

    def __call__(self, *args, **kwargs):
        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
        # Arrays fetched from Ray's object store are read-only, so copy them
        # before handing them back to the training loop.
        return weights.copy(), grads.copy()

    def __getattr__(self, name):
        return getattr(self.local_optimizer, name)

    def step_schedules(self):
        self.optimizer.step_schedules.remote()
        self.sync()
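
# Usage sketch (illustrative only, not part of the module): inside a training
# process the proxy is used like a regular optimizer. The config path, GPU id,
# rank, and world size below are placeholders, and how the underlying actor
# handle is shared between worker processes is up to the calling training
# script.
#
#     optimizer = RayOptimizer("config.cfg", use_gpu=-1, rank=0, world_size=2)
#     weights, grads = optimizer(key, weights, gradient, lr_scale=1.0)
#     optimizer.step_schedules()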