# spaCy/spacy/cli/ray_utils.py
"""Allreduce distributed training with Ray."""
import ray
from wasabi import msg
from .. import util
cp = None
nccl = None
from typing import Dict, Optional, Union, Tuple, List, cast
from thinc.types import FloatsXd


def _create_optimizer(config_path):
    """Load the training config and return the optimizer it defines."""
    msg.info(f"Loading config from: {config_path}")
    # Load once without creating objects so the seed can be fixed before the
    # objects (including the optimizer) are built on the second load.
    config = util.load_config(config_path, create_objects=False)
    util.fix_random_seed(config["training"]["seed"])  # Fix this.
    config = util.load_config(config_path, create_objects=True)
    training = config["training"]
    return training["optimizer"]


class RayWorker:
    """Holds one rank's NCCL communicator for allreduce training."""

    def __init__(self, rank, world_size):
        global nccl
        from cupy.cuda import nccl

        self.rank = rank
        self.world_size = world_size
        self.unique_id = nccl.get_unique_id()

    def initialize(self, head_id):
        # Join the NCCL communicator group identified by the head worker's id.
        self.communicator = nccl.NcclCommunicator(self.world_size, head_id, self.rank)

    def get_unique_id(self):
        return self.unique_id

    def execute(self, fn):
        # Run an arbitrary callable on this worker, passing the worker itself.
        return fn(self)
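

# Illustrative sketch (not part of the original module): one way the RayWorker
# actors above could be wired together from the driver. The helper name, the
# worker count, and num_gpus=1 are assumptions made for the example.
def _example_create_workers(n_workers: int = 2):
    ray.init(ignore_reinit_error=True)
    RemoteRayWorker = ray.remote(num_gpus=1)(RayWorker)
    workers = [RemoteRayWorker.remote(rank, n_workers) for rank in range(n_workers)]
    # Rank 0's NCCL unique id names the communicator group; every worker then
    # joins it in initialize().
    head_id = ray.get(workers[0].get_unique_id.remote())
    ray.get([worker.initialize.remote(head_id) for worker in workers])
    return workers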


class AllreduceOptimizer:
    """Wrap the config's optimizer so parameters stay in sync across workers
    via NCCL allreduce."""

    def __init__(self, config_path, communicator):
        global cp
        import cupy as cp
        global nccl
        from cupy.cuda import nccl

        self.optimizer = _create_optimizer(config_path)
        self.communicator = communicator
        self.weights_synced = set()

    def allreduce(self, tensor):
        # In-place allreduce: every worker ends up with the elementwise sum of
        # the tensor across all ranks.
        self.communicator.allReduce(
            tensor.data.ptr,
            tensor.data.ptr,
            tensor.size,
            nccl.NCCL_FLOAT32,
            nccl.NCCL_SUM,
            cp.cuda.Stream.null.ptr,
        )
        return tensor

    def __call__(
        self,
        key: Tuple[int, str],
        weights: FloatsXd,
        gradient: FloatsXd,
        *,
        lr_scale: float = 1.0,
    ):
        if key not in self.weights_synced:
            # First update for this parameter: average the weights so all
            # workers start from identical values.
            self.weights_synced.add(key)
            weights = self.allreduce(weights) / self.communicator.size()
        # Sum the gradients across workers, then apply the wrapped optimizer.
        gradient = self.allreduce(gradient)
        flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
        return flat_weights, gradient

    def __getattr__(self, name):
        # Delegate any other attribute access to the wrapped optimizer.
        return getattr(self.optimizer, name)
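

# Illustrative sketch (not part of the original module): what one update step
# looks like from a worker's point of view once its communicator exists. The
# parameter key, array shapes, and config path are made-up example values.
# Because the worker is its only required argument, it could be shipped to an
# actor via `worker_handle.execute.remote(_example_update)`.
def _example_update(worker, config_path: str = "config.cfg"):
    import cupy

    optimizer = AllreduceOptimizer(config_path, worker.communicator)
    key = (0, "W")  # thinc passes (model id, parameter name) tuples
    weights = cupy.zeros((4, 4), dtype="float32")
    gradient = cupy.ones((4, 4), dtype="float32")
    # Weights are averaged across workers the first time a key is seen;
    # gradients are summed on every call before the wrapped optimizer runs.
    weights, gradient = optimizer(key, weights, gradient, lr_scale=1.0)
    return weights, gradient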