ray-init
commit e6536279d4
parent a1c5b694be
spacy/cli/ray_utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import ray
+from wasabi import msg
+from .. import util
+
+
+class OptimizerWorker:
+    def __init__(self, config_path):
+        msg.info(f"Loading config from: {config_path}")
+        config = util.load_config(config_path, create_objects=False)
+        util.fix_random_seed(config["training"]["seed"])
+        config = util.load_config(config_path, create_objects=True)
+        training = config["training"]
+        optimizer = training["optimizer"]
+        self.optimizer = optimizer
+        self.weights_dict = {}
+
+    def call(self, key, weights, gradient, *, lr_scale=1.0):
+        if key not in self.weights_dict:
+            self.weights_dict[key] = weights.copy()
+        new_weights, new_grads = self.optimizer(
+            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
+        self.weights_dict[key] = new_weights
+        return new_weights, new_grads
+
+    def fetch(self):
+        return self.optimizer
+
+    def step_schedules(self):
+        self.optimizer.step_schedules()
+
+class RayOptimizer:
+    local_optimizer = None
+
+    def __init__(self, config_path):
+        RemoteOptimizer = ray.remote(OptimizerWorker)
+        self.optimizer = RemoteOptimizer.remote(config_path)
+        self.sync()
+
+    def sync(self):
+        self.local_optimizer = ray.get(self.optimizer.fetch.remote())
+
+    def __call__(self, *args, **kwargs):
+        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
+        return weights.copy(), grads.copy()
+
+    def __getattr__(self, name):
+        return getattr(self.local_optimizer, name)
+
+    def step_schedules(self):
+        self.optimizer.step_schedules.remote()
+        self.sync()
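Taken together, the new module implements a small parameter-server pattern: a single OptimizerWorker actor owns the optimizer state and the canonical copy of each parameter, while every training process holds a RayOptimizer handle that forwards updates to the actor and keeps a locally cached optimizer for attribute reads. A minimal sketch of driving the handle directly, assuming a local Ray runtime, a valid training config at the hypothetical path config.cfg, and placeholder key and array values:

import numpy
import ray
from spacy.cli.ray_utils import RayOptimizer

ray.init()
optimizer = RayOptimizer("config.cfg")

# The first call for a key registers the canonical weights on the actor;
# subsequent calls apply the gradient there and return the updated copies.
key = (0, "W")
weights = numpy.zeros((4,), dtype="f")
gradient = numpy.ones((4,), dtype="f")
new_weights, new_grads = optimizer(key, weights, gradient, lr_scale=1.0)

# Advances any schedules (e.g. the learn rate) on the actor, then re-syncs
# the cached local optimizer.
optimizer.step_schedules()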
spacy/cli/train.py

@@ -126,6 +126,7 @@ class ConfigSchema(BaseModel):
     raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
     verbose=("Display more information for debugging purposes", "flag", "VV", bool),
     use_gpu=("Use GPU", "option", "g", int),
+    num_workers=("Parallel Workers", "option", "j", int),
     tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
     omit_extra_lookups=("Don't include extra lookups in model", "flag", "OEL", bool),
     # fmt: on
@@ -139,6 +140,7 @@ def train_cli(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=1,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
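Because plac maps num_workers to the short option j, a multi-worker run would be requested from the command line roughly as follows (the remaining train arguments are unchanged and elided here):

python -m spacy train ... -j 2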
@@ -181,22 +183,33 @@ def train_cli(
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()
 
+    train_args = dict(
+        config_path=config_path,
+        data_paths={"train": train_path, "dev": dev_path},
+        output_path=output_path,
+        raw_text=raw_text,
+        tag_map=tag_map,
+        weights_data=weights_data,
+        omit_extra_lookups=omit_extra_lookups
+    )
+
+    if num_workers and num_workers > 1:
+        from spacy.cli.ray_utils import RayOptimizer
+        import ray
+        ray.init()
+        remote_train = ray.remote(setup_and_train)
+        train_args["remote_optimizer"] = RayOptimizer(config_path)
+        ray.get([remote_train.remote(use_gpu, train_args) for _ in range(num_workers)])
+    else:
+        setup_and_train(use_gpu, train_args)
+
+def setup_and_train(use_gpu, train_args):
     if use_gpu >= 0:
         msg.info(f"Using GPU: {use_gpu}")
         util.use_gpu(use_gpu)
     else:
         msg.info("Using CPU")
-    train(
-        config_path,
-        {"train": train_path, "dev": dev_path},
-        output_path=output_path,
-        raw_text=raw_text,
-        tag_map=tag_map,
-        weights_data=weights_data,
-        omit_extra_lookups=omit_extra_lookups,
-    )
+    train(**train_args)
 
 
 def train(
     config_path,
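The dispatch above is the standard Ray fan-out idiom: wrap a plain function with ray.remote, launch one task per worker, and block on the returned futures with ray.get. A stripped-down sketch of the same idiom, with a stand-in work function in place of setup_and_train:

import ray

def work(worker_id, shared_args):
    # A real worker would run the full training loop here.
    return worker_id

ray.init()
remote_work = ray.remote(work)
futures = [remote_work.remote(i, {"seed": 0}) for i in range(2)]
print(ray.get(futures))  # blocks until both worker tasks finish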
@@ -206,6 +219,7 @@ def train(
     tag_map=None,
     weights_data=None,
     omit_extra_lookups=False,
+    remote_optimizer=None
 ):
     msg.info(f"Loading config from: {config_path}")
     # Read the config first without creating objects, to get to the original nlp_config
@@ -220,6 +234,8 @@ def train(
     msg.info("Creating nlp from config")
     nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
+    if remote_optimizer:
+        optimizer = remote_optimizer
     limit = training["limit"]
     msg.info("Loading training corpus")
     corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
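The swap is transparent to the rest of train() because the training loop only ever calls the optimizer and reads attributes such as learn_rate from it, and RayOptimizer mirrors both: __call__ forwards each update to the actor, while __getattr__ serves attribute reads from the locally cached optimizer (so they are only as fresh as the last sync()). The delegation reduced to a toy, with hypothetical names:

class Handle:
    def __init__(self, real):
        self._real = real

    def __call__(self, *args, **kwargs):
        return self._real(*args, **kwargs)

    def __getattr__(self, name):
        # Only consulted for names not found on Handle itself.
        return getattr(self._real, name)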
@@ -332,7 +348,6 @@ def train(
         eval_frequency=training["eval_frequency"],
         raw_text=raw_text,
     )
-
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)