spaCy/spacy/cli/ray_utils.py
Richard Liaw 82a3a3b9a7 lint
2020-06-12 19:34:20 -07:00

52 lines
1.6 KiB
Python

import ray
from wasabi import msg
from .. import util
class OptimizerWorker:
def __init__(self, config_path):
msg.info(f"Loading config from: {config_path}")
config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["training"]["seed"])
config = util.load_config(config_path, create_objects=True)
training = config["training"]
optimizer = training["optimizer"]
self.optimizer = optimizer
self.weights_dict = {}
def call(self, key, weights, gradient, *, lr_scale=1.0):
if key not in self.weights_dict:
self.weights_dict[key] = weights.copy()
new_weights, new_grads = self.optimizer(
key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
self.weights_dict[key] = new_weights
return new_weights, new_grads
def fetch(self):
return self.optimizer
def step_schedules(self):
self.optimizer.step_schedules()
class RayOptimizer:
local_optimizer = None
def __init__(self, config_path):
RemoteOptimizer = ray.remote(OptimizerWorker)
self.optimizer = RemoteOptimizer.remote(config_path)
self.sync()
def sync(self):
self.local_optimizer = ray.get(self.optimizer.fetch.remote())
def __call__(self, *args, **kwargs):
weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
return weights.copy(), grads.copy()
def __getattr__(self, name):
return getattr(self.local_optimizer, name)
def step_schedules(self):
self.optimizer.step_schedules.remote()
self.sync()