Mirror of https://github.com/explosion/spaCy.git, synced 2025-01-31 11:46:22 +03:00

commit 26c975ec66
parent fdc9242bc1

    sync
@@ -1,23 +1,44 @@
 """Parameter Server distributed training with Ray."""
+import threading
 import ray
 from wasabi import msg
 from .. import util
+from spacy.cli.ray_utils import create_optimizer
 
 
 class OptimizerWorker:
-    def __init__(self, config_path, world_size, sync=True):
-        self.optimizer = _create_optimizer(config_path)
+    def __init__(self, config_path, world_size):
+        self.optimizer = create_optimizer(config_path)
+        self.new_weights = None
+        self.barrier = threading.Barrier(world_size)
+        self.lock = threading.Lock()
+        self.waiting = 0
         self.weights_dict = {}
+        self.gradient = {}
         self.world_size = world_size
-        self.sync = sync
 
-    def call(self, rank, key, weights, gradient, *, lr_scale=1.0):
-        if key not in self.weights_dict:
-            self.weights_dict[key] = weights.copy()
-        new_weights, new_grads = self.optimizer(
-            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
-        self.weights_dict[key] = new_weights
-        return new_weights, new_grads
+    def call(self, key, weights, gradient, *, lr_scale=1.0):
+        self.lock.acquire()
+        if self.waiting < self.world_size - 1:
+            if self.waiting == 0:
+                self.gradient[key] = gradient.copy()
+                self.weights_dict[key] = weights.copy()
+            else:
+                self.gradient[key] += gradient
+            self.waiting = self.barrier.n_waiting + 1
+            self.lock.release()
+            self.barrier.wait()
+        else:
+            self.gradient[key] += gradient
+            self.lock.release()
+            self.gradient[key] /= self.world_size
+            new_weights, new_grads = self.optimizer(
+                key, self.weights_dict[key], self.gradient[key], lr_scale=lr_scale)
+            self.weights_dict[key] = new_weights
+            self.gradient[key] = new_grads
+            self.waiting = 0
+            self.barrier.wait()
+        return self.weights_dict[key], self.gradient[key]
 
     def fetch(self):
         return self.optimizer
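The rewritten call() above combines a lock (to accumulate each worker's gradient into a shared buffer) with a threading.Barrier (to hold every worker until one of them has averaged the gradients and applied the optimizer step). The following is a minimal, self-contained sketch of that same pattern, not code from this commit: plain threads stand in for the Ray workers and a toy sgd_step stands in for the Thinc optimizer.

import threading
import numpy as np

def sgd_step(weights, grad, lr=0.001):
    # Stand-in for the Thinc optimizer: one plain SGD update.
    return weights - lr * grad

class GradAverager:
    """Each of `world_size` threads contributes a gradient; one applies the update."""

    def __init__(self, world_size):
        self.world_size = world_size
        self.lock = threading.Lock()
        self.barrier = threading.Barrier(world_size)
        self.weights = None
        self.grad = None

    def call(self, weights, gradient):
        with self.lock:
            if self.grad is None:
                # First caller of the round seeds the shared buffers.
                self.weights = weights.copy()
                self.grad = gradient.copy()
            else:
                # Later callers accumulate into the shared gradient.
                self.grad += gradient
        # First barrier: every caller has contributed its gradient.
        if self.barrier.wait() == 0:
            # Exactly one caller averages and applies the update.
            self.grad /= self.world_size
            self.weights = sgd_step(self.weights, self.grad)
            self.grad = None  # reset for the next round
        # Second barrier: the update is visible before anyone returns.
        self.barrier.wait()
        return self.weights.copy()

# Tiny demo with two threads standing in for two workers.
averager = GradAverager(world_size=2)
results = []
threads = [
    threading.Thread(target=lambda: results.append(averager.call(np.ones(3), np.ones(3))))
    for _ in range(2)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results[0])  # [0.999 0.999 0.999]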
@@ -40,7 +61,7 @@ class RayOptimizer:
         self.local_optimizer = ray.get(self.optimizer.fetch.remote())
 
     def __call__(self, *args, **kwargs):
-        weights, grads = ray.get(self.optimizer.call.remote(self.rank, *args, **kwargs))
+        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
         return weights.copy(), grads.copy()
 
     def __getattr__(self, name):
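The RayOptimizer hunk only shows the changed call; the actor wiring itself is not part of this diff. Below is a hedged sketch of how OptimizerWorker is presumably used as a shared Ray actor. The config path, parameter key, and max_concurrency setting are illustrative assumptions rather than code from the commit; max_concurrency must be greater than 1, otherwise the threading.Barrier inside call() could never be released.

import numpy as np
import ray

ray.init()

# One shared parameter-server actor built from OptimizerWorker (illustrative wiring).
server = (
    ray.remote(OptimizerWorker)
    .options(max_concurrency=2)        # lets world_size calls run concurrently
    .remote("config.cfg", 2)           # config path and world_size are placeholders
)

@ray.remote
def train_step(worker_id):
    key = "tok2vec_W"                                  # illustrative parameter id
    weights = np.zeros(8, dtype="float32")
    gradient = np.full(8, worker_id + 1.0, dtype="float32")
    # After this commit the worker no longer passes its rank; the call blocks
    # until both workers have contributed, then returns the averaged results.
    return ray.get(server.call.remote(key, weights, gradient))

results = ray.get([train_step.remote(i) for i in range(2)])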
@@ -10,7 +10,7 @@ nccl = None
 from typing import Dict, Optional, Union, Tuple, List, cast
 from thinc.types import FloatsXd
 
-def _create_optimizer(config_path):
+def create_optimizer(config_path):
     msg.info(f"Loading config from: {config_path}")
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])  # Fix this.
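With the leading underscore dropped, the helper becomes part of the module's public surface, matching the new import in the parameter-server file above. A one-line illustration (the config path is a placeholder):

from spacy.cli.ray_utils import create_optimizer

optimizer = create_optimizer("training_config.cfg")  # placeholder path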
@@ -41,7 +41,7 @@ class AllreduceOptimizer:
         import cupy as cp
         global nccl
         from cupy.cuda import nccl
-        self.optimizer = _create_optimizer(config_path)
+        self.optimizer = create_optimizer(config_path)
         self.communicator = communicator
         self.weights_synced = set()
 
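The AllreduceOptimizer hunk above only renames the helper, but for context, here is a hedged sketch of the kind of gradient averaging the imported CuPy NCCL bindings enable. The function and its arguments are illustrative, not code from this commit, and `communicator` is assumed to be an already-initialized cupy.cuda.nccl.NcclCommunicator shared by all workers.

import cupy as cp
from cupy.cuda import nccl

def allreduce_mean(grad, communicator, world_size):
    # Sum `grad` (a CuPy float32 array) across all workers in place,
    # then divide by the world size to get the averaged gradient.
    communicator.allReduce(
        grad.data.ptr,            # send buffer
        grad.data.ptr,            # receive buffer (in-place reduction)
        grad.size,                # number of elements
        nccl.NCCL_FLOAT32,        # element dtype
        nccl.NCCL_SUM,            # reduction operation
        cp.cuda.Stream.null.ptr,  # default CUDA stream
    )
    grad /= world_size
    return grad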