Mirror of https://github.com/explosion/spaCy.git (synced 2025-02-07 07:00:34 +03:00)
move-ps

parent d1de4b1ea9
commit a5a3ed722c
spacy/cli/ray_param_server.py (new file, +48 lines)

@@ -0,0 +1,48 @@
+"""Parameter Server distributed training with Ray."""
+
+import ray
+from wasabi import msg
+from .. import util
+
+class OptimizerWorker:
+    def __init__(self, config_path):
+        self.optimizer = _create_optimizer(config_path)
+        self.weights_dict = {}
+
+    def call(self, key, weights, gradient, *, lr_scale=1.0):
+        if key not in self.weights_dict:
+            self.weights_dict[key] = weights.copy()
+        new_weights, new_grads = self.optimizer(
+            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
+        self.weights_dict[key] = new_weights
+        return new_weights, new_grads
+
+    def fetch(self):
+        return self.optimizer
+
+    def step_schedules(self):
+        self.optimizer.step_schedules()
+
+class RayOptimizer:
+    local_optimizer = None
+
+    def __init__(self, config_path, use_gpu):
+        RemoteOptimizer = ray.remote(OptimizerWorker)
+        if use_gpu >= 0:
+            RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
+        self.optimizer = RemoteOptimizer.remote(config_path)
+        self.sync()
+
+    def sync(self):
+        self.local_optimizer = ray.get(self.optimizer.fetch.remote())
+
+    def __call__(self, *args, **kwargs):
+        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
+        return weights.copy(), grads.copy()
+
+    def __getattr__(self, name):
+        return getattr(self.local_optimizer, name)
+
+    def step_schedules(self):
+        self.optimizer.step_schedules.remote()
+        self.sync()
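Editor's note: the two classes implement a simple parameter-server pattern. A single OptimizerWorker actor owns the optimizer state and the canonical copy of each parameter, and every training process talks to it through a RayOptimizer handle. Note also that _create_optimizer is referenced but neither defined nor imported in the 48 added lines; presumably it is shared with (or duplicated from) ray_utils.py below. A hypothetical driver loop, not part of the commit, showing how the handle would be used (the config path, the key, and the gradient computation are placeholders):

# Hypothetical sketch only; nothing below appears in the commit itself.
import numpy
import ray

from spacy.cli.ray_param_server import RayOptimizer

ray.init()
optimizer = RayOptimizer("config.cfg", use_gpu=-1)   # "config.cfg" is a placeholder path

key = ("tok2vec", "W")                    # stand-in for a per-parameter key as Thinc would pass
weights = numpy.zeros((10,), dtype="f")   # local copy of one parameter
for step in range(100):
    # Stand-in for the real backprop gradient.
    gradient = numpy.random.uniform(-1, 1, weights.shape).astype("f")
    # The actor applies the update to its canonical copy and returns it,
    # so every worker using the same key converges on the same weights.
    weights, gradient = optimizer(key, weights, gradient)
optimizer.step_schedules()                # advance LR schedules on the actor, then re-sync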
spacy/cli/ray_utils.py

@@ -1,3 +1,5 @@
+"""Allreduce distributed training with Ray."""
+
 import ray
 from wasabi import msg
 from .. import util
@@ -16,49 +18,6 @@ def _create_optimizer(config_path):
     training = config["training"]
     return training["optimizer"]
 
-class OptimizerWorker:
-    def __init__(self, config_path):
-        self.optimizer = _create_optimizer(config_path)
-        self.weights_dict = {}
-
-    def call(self, key, weights, gradient, *, lr_scale=1.0):
-        if key not in self.weights_dict:
-            self.weights_dict[key] = weights.copy()
-        new_weights, new_grads = self.optimizer(
-            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
-        self.weights_dict[key] = new_weights
-        return new_weights, new_grads
-
-    def fetch(self):
-        return self.optimizer
-
-    def step_schedules(self):
-        self.optimizer.step_schedules()
-
-class RayOptimizer:
-    local_optimizer = None
-
-    def __init__(self, config_path, use_gpu):
-        RemoteOptimizer = ray.remote(OptimizerWorker)
-        if use_gpu >= 0:
-            RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
-        self.optimizer = RemoteOptimizer.remote(config_path)
-        self.sync()
-
-    def sync(self):
-        self.local_optimizer = ray.get(self.optimizer.fetch.remote())
-
-    def __call__(self, *args, **kwargs):
-        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
-        return weights.copy(), grads.copy()
-
-    def __getattr__(self, name):
-        return getattr(self.local_optimizer, name)
-
-    def step_schedules(self):
-        self.optimizer.step_schedules.remote()
-        self.sync()
-
 class RayWorker:
     def __init__(self, rank, world_size):
         global nccl
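Editor's note: the hunk shows only the last two lines of _create_optimizer; the rest of its body is elided. Based on the `from .. import util` import and the context lines, it plausibly reads along these lines (a reconstruction under the assumption that the era's util.load_config(path, create_objects=True) loader is used, not the commit's exact code):

def _create_optimizer(config_path):
    # Reconstruction of the elided body: load the training config with
    # registered objects resolved, then hand back the instantiated
    # optimizer from its [training] section.
    config = util.load_config(config_path, create_objects=True)
    training = config["training"]
    return training["optimizer"]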
spacy/cli/train.py

@@ -143,7 +143,7 @@ def train_cli(
     verbose=False,
     use_gpu=-1,
     num_workers=1,
-    strategy="ps",
+    strategy="allreduce",
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
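Editor's note: this flips the default coordination strategy for multi-worker training from the parameter server to allreduce, so the parameter-server path now has to be requested explicitly with strategy="ps".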
@@ -197,10 +197,10 @@ def train_cli(
     )
 
     if num_workers and num_workers > 1:
-        from spacy.cli.ray_utils import RayOptimizer
         import ray
         ray.init()
         if strategy == "ps":
+            from spacy.cli.ray_param_server import RayOptimizer
         remote_train = ray.remote(setup_and_train)
         if use_gpu >= 0:
             msg.info("Enabling GPU with Ray")
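Editor's note: the unconditional ray_utils import is removed and only the "ps" branch now imports a RayOptimizer, so the allreduce path must pick one up outside this hunk (or in a later commit). A sketch of the dispatch this appears to be heading toward; the else branch is an assumption, not part of the diff, and num_workers, strategy, and setup_and_train come from the surrounding train_cli context:

# Sketch only: reconstructs the strategy dispatch around this hunk.
if num_workers and num_workers > 1:
    import ray

    ray.init()
    if strategy == "ps":
        from spacy.cli.ray_param_server import RayOptimizer
    else:
        # Assumption: the allreduce variant stays in ray_utils.
        from spacy.cli.ray_utils import RayOptimizer
    remote_train = ray.remote(setup_and_train)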