spaCy/spacy/cli/ray_param_server.py

"""Parameter Server distributed training with Ray."""
import threading
import ray
from wasabi import msg
from .. import util
from .ray_utils import create_optimizer


class OptimizerWorker:
    def __init__(self, config_path, world_size):
        self.optimizer = create_optimizer(config_path)
        self.new_weights = None
        # All `world_size` workers must check in before an update is applied.
        self.barrier = threading.Barrier(world_size)
        self.lock = threading.Lock()
        self.waiting = 0
        # Per-parameter weights and accumulated gradients, keyed by
        # parameter id.
        self.weights_dict = {}
        self.grad_dict = {}
        self.world_size = world_size

    def call(self, key, weights, gradient, *, lr_scale=1.0):
        self.lock.acquire()
        if self.waiting < self.world_size - 1:
            # Not the last worker to arrive: contribute this gradient, then
            # block on the barrier until the last arrival applies the update.
            if self.waiting == 0:
                self.grad_dict[key] = gradient.copy()
                self.weights_dict[key] = weights.copy()
            else:
                self.grad_dict[key] += gradient
            # Record how many workers are parked on the barrier, this one
            # included.
            self.waiting = self.barrier.n_waiting + 1
            self.lock.release()
            self.barrier.wait()
        else:
            # Last worker to arrive: average the accumulated gradients, take
            # one optimizer step, then release the waiting workers.
            self.grad_dict[key] += gradient
            self.lock.release()
            self.grad_dict[key] /= self.world_size
            new_weights, new_grads = self.optimizer(
                key, self.weights_dict[key], self.grad_dict[key], lr_scale=lr_scale
            )
            self.weights_dict[key] = new_weights
            self.grad_dict[key] = new_grads
            self.waiting = 0
            self.barrier.wait()
        return self.weights_dict[key], self.grad_dict[key]

    def fetch(self):
        return self.optimizer

    def step_schedules(self):
        self.optimizer.step_schedules()
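

# Worked example of the rendezvous in OptimizerWorker.call with
# world_size=2, tracing the code above:
#
#   1. Rank 0 arrives first (waiting == 0): it seeds grad_dict[key] and
#      weights_dict[key] with copies of its arrays, then blocks on the
#      barrier.
#   2. Rank 1 arrives last (waiting == world_size - 1): it adds its
#      gradient, divides the sum by world_size, applies one optimizer step,
#      and hits the barrier, releasing rank 0.
#   3. Both ranks return the same averaged gradients and updated weights.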
class RayOptimizer:
    # Class-level default so attribute lookup works before sync() runs.
    local_optimizer = None

    def __init__(self, config_path, use_gpu, rank, world_size):
        RemoteOptimizer = ray.remote(OptimizerWorker)
        if use_gpu >= 0:
            # Reserve a fraction of a GPU for the parameter server actor.
            RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
        # The actor's barrier needs to know how many workers will call in.
        self.optimizer = RemoteOptimizer.remote(config_path, world_size)
        self.rank = rank
        self.sync()

    def sync(self):
        # Mirror the remote optimizer locally so attribute access is cheap.
        self.local_optimizer = ray.get(self.optimizer.fetch.remote())

    def __call__(self, *args, **kwargs):
        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
        return weights.copy(), grads.copy()

    def __getattr__(self, name):
        # Delegate attributes (e.g. learn rate schedules) to the local mirror.
        return getattr(self.local_optimizer, name)

    def step_schedules(self):
        self.optimizer.step_schedules.remote()
        self.sync()
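

# ---------------------------------------------------------------------------
# Usage sketch, not part of the module proper: a minimal guess at how a
# driver could wire this up, assuming a running Ray cluster and a valid
# training config. `_example`, `train_worker`, `n_workers`, and "config.cfg"
# are hypothetical names introduced here for illustration. Passing the one
# RayOptimizer into every task shares its actor handle, so all ranks hit
# the same barrier on the same OptimizerWorker.
def _example():
    n_workers = 2

    @ray.remote
    def train_worker(rank, optimizer):
        # A real training loop would call the shared optimizer like a
        # regular Thinc optimizer for each parameter:
        #     weights, grads = optimizer(key, weights, gradient)
        # OptimizerWorker.call blocks until all n_workers ranks have
        # contributed a gradient for that key.
        return rank

    ray.init()
    optimizer = RayOptimizer("config.cfg", use_gpu=-1, rank=0, world_size=n_workers)
    ray.get([train_worker.remote(rank, optimizer) for rank in range(n_workers)])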