Richard Liaw 2020-06-17 21:51:22 -07:00
parent fdc9242bc1
commit 26c975ec66
2 changed files with 35 additions and 14 deletions

View File

@@ -1,23 +1,44 @@
 """Parameter Server distributed training with Ray."""
+import threading
 import ray
 from wasabi import msg
 from .. import util
+from spacy.cli.ray_utils import create_optimizer


 class OptimizerWorker:
-    def __init__(self, config_path, world_size, sync=True):
-        self.optimizer = _create_optimizer(config_path)
+    def __init__(self, config_path, world_size):
+        self.optimizer = create_optimizer(config_path)
+        self.new_weights = None
+        self.barrier = threading.Barrier(world_size)
+        self.lock = threading.Lock()
+        self.waiting = 0
         self.weights_dict = {}
+        self.grad_dict = {}
         self.world_size = world_size
-        self.sync = sync

-    def call(self, rank, key, weights, gradient, *, lr_scale=1.0):
-        if key not in self.weights_dict:
-            self.weights_dict[key] = weights.copy()
-        new_weights, new_grads = self.optimizer(
-            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
-        self.weights_dict[key] = new_weights
-        return new_weights, new_grads
+    def call(self, key, weights, gradient, *, lr_scale=1.0):
+        self.lock.acquire()
+        if self.waiting < self.world_size - 1:
+            if self.waiting == 0:
+                self.gradient[key] = gradient.copy()
+                self.weights_dict[key] = weights.copy()
+            else:
+                self.gradient[key] += gradient
+            self.waiting = self.barrier.n_waiting + 1
+            self.lock.release()
+            self.barrier.wait()
+        else:
+            self.gradient[key] += gradient
+            self.lock.release()
+            self.gradient[key] /= self.world_size
+            new_weights, new_grads = self.optimizer(
+                key, self.weights_dict[key], self.gradient[key], lr_scale=lr_scale)
+            self.weights_dict[key] = new_weights
+            self.gradient[key] = new_grads
+            self.waiting = 0
+            self.barrier.wait()
+        return self.weights_dict[key], self.gradient[key]

     def fetch(self):
         return self.optimizer
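
To make the new control flow easier to follow, here is a minimal, self-contained sketch of the same barrier-synchronized update. Note that the hunk above initializes self.grad_dict but reads and writes self.gradient inside call(), which looks like a naming slip; the sketch uses a single grad_dict buffer throughout. The class name, the plain-SGD step, and the threaded driver are illustrative stand-ins, not spaCy code; the committed version delegates the actual update to the Thinc optimizer built by create_optimizer.

import threading
import numpy as np


class SyncParamServer:
    """Toy stand-in for OptimizerWorker: the last caller to arrive averages the
    accumulated gradients and applies one update; every other caller blocks on
    the barrier and then reads the freshly updated weights."""

    def __init__(self, world_size, lr=0.1):
        self.world_size = world_size
        self.barrier = threading.Barrier(world_size)
        self.lock = threading.Lock()
        self.waiting = 0
        self.weights_dict = {}
        self.grad_dict = {}  # single name for the gradient buffer
        self.lr = lr

    def call(self, key, weights, gradient):
        with self.lock:
            if self.waiting == 0:
                self.grad_dict[key] = gradient.copy()
                self.weights_dict[key] = weights.copy()
            else:
                self.grad_dict[key] += gradient
            self.waiting += 1
            is_last = self.waiting == self.world_size
        if is_last:
            # Average and apply a toy SGD step in place of the Thinc optimizer.
            self.grad_dict[key] /= self.world_size
            self.weights_dict[key] -= self.lr * self.grad_dict[key]
            self.waiting = 0
        self.barrier.wait()
        return self.weights_dict[key].copy(), self.grad_dict[key].copy()


if __name__ == "__main__":
    ps = SyncParamServer(world_size=4)
    results = []

    def worker(i):
        results.append(ps.call("W", np.zeros(3), np.ones(3) * (i + 1)))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results[0][0])  # mean gradient 2.5, lr 0.1 -> [-0.25 -0.25 -0.25]

The lock only protects gradient accumulation; it is the barrier that guarantees no worker reads the weights before the last arrival has applied the averaged gradient.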
@@ -40,7 +61,7 @@ class RayOptimizer:
         self.local_optimizer = ray.get(self.optimizer.fetch.remote())

     def __call__(self, *args, **kwargs):
-        weights, grads = ray.get(self.optimizer.call.remote(self.rank, *args, **kwargs))
+        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
         return weights.copy(), grads.copy()

     def __getattr__(self, name):
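
On the Ray side, RayOptimizer.__call__ now forwards its arguments unchanged instead of prepending self.rank. For readers unfamiliar with Ray, the sketch below shows the actor pattern the class relies on: wrapping a stateful class with ray.remote, invoking its methods with .remote(), and blocking on the result with ray.get. The SharedState class and the counter values are illustrative only, not spaCy code.

import ray


class SharedState:
    """Stand-in for OptimizerWorker: one actor process owns the mutable state."""

    def __init__(self):
        self.calls = 0

    def call(self, delta):
        self.calls += delta
        return self.calls


ray.init(num_cpus=2)

# ray.remote turns the class into an actor; .remote() starts the actor process.
Worker = ray.remote(SharedState)
actor = Worker.remote()

# Each method invocation returns an ObjectRef; ray.get blocks until it resolves,
# just as RayOptimizer.__call__ blocks on optimizer.call.remote(...).
print(ray.get([actor.call.remote(1) for _ in range(4)]))  # [1, 2, 3, 4]
ray.shutdown()

One caveat worth flagging: a default Ray actor executes one method call at a time, so the threading.Barrier inside OptimizerWorker.call can only be released if the actor is created as a threaded actor (for example with max_concurrency set at least as high as the world size); presumably the code that instantiates the actor, which is outside this diff, does that.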

View File

@@ -10,7 +10,7 @@ nccl = None
 from typing import Dict, Optional, Union, Tuple, List, cast
 from thinc.types import FloatsXd

-def _create_optimizer(config_path):
+def create_optimizer(config_path):
     msg.info(f"Loading config from: {config_path}")
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])  # Fix this.
@@ -41,7 +41,7 @@ class AllreduceOptimizer:
         import cupy as cp
         global nccl
         from cupy.cuda import nccl
-        self.optimizer = _create_optimizer(config_path)
+        self.optimizer = create_optimizer(config_path)
         self.communicator = communicator
         self.weights_synced = set()
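
The AllreduceOptimizer path synchronizes gradients with NCCL via CuPy rather than through a parameter-server actor. As a rough orientation, here is a single-process, single-GPU sketch of an NCCL allreduce followed by averaging; it requires a CUDA device and a CuPy build with NCCL support, uses a world size of 1 purely for illustration, and is not the spaCy implementation.

import cupy as cp
from cupy.cuda import nccl

# Single-rank communicator: with N workers, every rank would build it from the
# same unique id shared out of band, passing its own rank.
world_size, rank = 1, 0
comm = nccl.NcclCommunicator(world_size, nccl.get_unique_id(), rank)

grad = cp.arange(4, dtype=cp.float32)   # stand-in for one parameter's gradient
summed = cp.empty_like(grad)

# Sum the buffer across ranks, then divide by the world size to average.
stream = cp.cuda.Stream.null
comm.allReduce(grad.data.ptr, summed.data.ptr, grad.size,
               nccl.NCCL_FLOAT32, nccl.NCCL_SUM, stream.ptr)
stream.synchronize()
print(summed / world_size)  # with one rank this is just the original gradient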