Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-31 03:34:07 +03:00)
ray-init

commit e6536279d4
parent a1c5b694be
spacy/cli/ray_utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import ray
from wasabi import msg

from .. import util


class OptimizerWorker:
    """Ray actor that owns the canonical optimizer and parameter copies."""

    def __init__(self, config_path):
        msg.info(f"Loading config from: {config_path}")
        config = util.load_config(config_path, create_objects=False)
        util.fix_random_seed(config["training"]["seed"])
        config = util.load_config(config_path, create_objects=True)
        training = config["training"]
        optimizer = training["optimizer"]
        self.optimizer = optimizer
        self.weights_dict = {}

    def call(self, key, weights, gradient, *, lr_scale=1.0):
        # Apply the gradient to the central copy of this parameter and
        # return the updated weights to the calling worker.
        if key not in self.weights_dict:
            self.weights_dict[key] = weights.copy()
        new_weights, new_grads = self.optimizer(
            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
        self.weights_dict[key] = new_weights
        return new_weights, new_grads

    def fetch(self):
        return self.optimizer

    def step_schedules(self):
        self.optimizer.step_schedules()


class RayOptimizer:
    """Drop-in optimizer proxy that forwards updates to the actor."""

    local_optimizer = None

    def __init__(self, config_path):
        RemoteOptimizer = ray.remote(OptimizerWorker)
        self.optimizer = RemoteOptimizer.remote(config_path)
        self.sync()

    def sync(self):
        # Mirror the actor's optimizer locally so attribute reads
        # (e.g. learn_rate) don't need a round trip.
        self.local_optimizer = ray.get(self.optimizer.fetch.remote())

    def __call__(self, *args, **kwargs):
        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
        return weights.copy(), grads.copy()

    def __getattr__(self, name):
        return getattr(self.local_optimizer, name)

    def step_schedules(self):
        self.optimizer.step_schedules.remote()
        self.sync()
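The new file implements a simple parameter-server pattern: OptimizerWorker is a Ray actor holding the single authoritative copy of each parameter's weights, and every training process routes its gradients through it. A minimal toy sketch of the same pattern, independent of spaCy (ToyOptimizerWorker and its plain-SGD update are illustrative assumptions, not part of this commit):

    import ray

    class ToyOptimizerWorker:
        # Toy stand-in for OptimizerWorker: plain SGD on a central copy.
        def __init__(self, lr=0.1):
            self.lr = lr
            self.weights_dict = {}

        def call(self, key, weights, gradient):
            if key not in self.weights_dict:
                self.weights_dict[key] = list(weights)
            self.weights_dict[key] = [
                w - self.lr * g for w, g in zip(self.weights_dict[key], gradient)
            ]
            return self.weights_dict[key]

    ray.init()
    worker = ray.remote(ToyOptimizerWorker).remote()

    # Two "trainers" report gradients for the same key; the second update is
    # applied on top of the first because the actor holds the canonical copy.
    print(ray.get(worker.call.remote("layer0", [1.0, 1.0], [0.5, 0.5])))  # [0.95, 0.95]
    print(ray.get(worker.call.remote("layer0", [1.0, 1.0], [0.5, 0.5])))  # [0.9, 0.9]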
spacy/cli/train.py

@@ -126,6 +126,7 @@ class ConfigSchema(BaseModel):
     raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
     verbose=("Display more information for debugging purposes", "flag", "VV", bool),
     use_gpu=("Use GPU", "option", "g", int),
+    num_workers=("Parallel Workers", "option", "j", int),
     tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
     omit_extra_lookups=("Don't include extra lookups in model", "flag", "OEL", bool),
     # fmt: on
@@ -139,6 +140,7 @@ def train_cli(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=1,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
@@ -181,22 +183,33 @@ def train_cli(
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()
 
+    train_args = dict(
+        config_path=config_path,
+        data_paths={"train": train_path, "dev": dev_path},
+        output_path=output_path,
+        raw_text=raw_text,
+        tag_map=tag_map,
+        weights_data=weights_data,
+        omit_extra_lookups=omit_extra_lookups
+    )
+
+    if num_workers and num_workers > 1:
+        from spacy.cli.ray_utils import RayOptimizer
+        import ray
+        ray.init()
+        remote_train = ray.remote(setup_and_train)
+        train_args["remote_optimizer"] = RayOptimizer(config_path)
+        ray.get([remote_train.remote(use_gpu, train_args) for _ in range(num_workers)])
+    else:
+        setup_and_train(use_gpu, train_args)
+
+def setup_and_train(use_gpu, train_args):
     if use_gpu >= 0:
         msg.info(f"Using GPU: {use_gpu}")
         util.use_gpu(use_gpu)
     else:
         msg.info("Using CPU")
-
-    train(
-        config_path,
-        {"train": train_path, "dev": dev_path},
-        output_path=output_path,
-        raw_text=raw_text,
-        tag_map=tag_map,
-        weights_data=weights_data,
-        omit_extra_lookups=omit_extra_lookups,
-    )
-
+    train(**train_args)
 
 def train(
     config_path,
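When num_workers > 1, train_cli fans the whole setup_and_train function out as Ray tasks, one per worker, all sharing a single RayOptimizer; ray.get blocks until every worker's training loop finishes. A minimal sketch of that fan-out pattern with a stub in place of setup_and_train (the stub body is an assumption for illustration):

    import ray

    def setup_and_train_stub(use_gpu, train_args):
        # Stand-in for setup_and_train: each task would run the full training
        # loop in its own process, sharing only the remote optimizer.
        return f"trained with use_gpu={use_gpu}, args: {sorted(train_args)}"

    ray.init()
    remote_train = ray.remote(setup_and_train_stub)
    # Launch two workers and block until both finish, as train_cli does.
    results = ray.get([remote_train.remote(-1, {"remote_optimizer": None}) for _ in range(2)])
    print(results)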
@@ -206,6 +219,7 @@ def train(
     tag_map=None,
     weights_data=None,
     omit_extra_lookups=False,
+    remote_optimizer=None
 ):
     msg.info(f"Loading config from: {config_path}")
     # Read the config first without creating objects, to get to the original nlp_config
@@ -220,6 +234,8 @@ def train(
     msg.info("Creating nlp from config")
     nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
+    if remote_optimizer:
+        optimizer = remote_optimizer
     limit = training["limit"]
     msg.info("Loading training corpus")
     corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
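The swap works because RayOptimizer mimics the interface train() relies on: it is callable with (key, weights, gradient) and exposes attributes such as learn_rate through __getattr__ on the synced local copy. A toy sketch of that duck-typing contract (LocalOptimizer and run_one_update are illustrative assumptions, not commit code):

    class LocalOptimizer:
        # Toy local optimizer exposing the same surface a RayOptimizer mirrors.
        learn_rate = 0.001

        def __call__(self, key, weights, gradient, *, lr_scale=1.0):
            return weights, gradient

    def run_one_update(optimizer):
        # train() only needs these two behaviours from either implementation:
        print("learn rate:", optimizer.learn_rate)  # attribute access
        return optimizer("k", [1.0], [0.0])         # gradient update call

    run_one_update(LocalOptimizer())  # a RayOptimizer slots in the same way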
@@ -332,7 +348,6 @@ def train(
         eval_frequency=training["eval_frequency"],
         raw_text=raw_text,
     )
 
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)