Richard Liaw 2020-06-12 19:12:17 -07:00
parent a1c5b694be
commit e6536279d4
2 changed files with 78 additions and 12 deletions

spacy/cli/ray_utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import ray
+from wasabi import msg
+from .. import util
+
+class OptimizerWorker:
+    def __init__(self, config_path):
+        msg.info(f"Loading config from: {config_path}")
+        config = util.load_config(config_path, create_objects=False)
+        util.fix_random_seed(config["training"]["seed"])
+        config = util.load_config(config_path, create_objects=True)
+        training = config["training"]
+        optimizer = training["optimizer"]
+        self.optimizer = optimizer
+        self.weights_dict = {}
+
+    def call(self, key, weights, gradient, *, lr_scale=1.0):
+        if key not in self.weights_dict:
+            self.weights_dict[key] = weights.copy()
+        new_weights, new_grads = self.optimizer(
+            key, self.weights_dict[key], gradient.copy(), lr_scale=lr_scale)
+        self.weights_dict[key] = new_weights
+        return new_weights, new_grads
+
+    def fetch(self):
+        return self.optimizer
+
+    def step_schedules(self):
+        self.optimizer.step_schedules()
+
+
+class RayOptimizer:
+    local_optimizer = None
+
+    def __init__(self, config_path):
+        RemoteOptimizer = ray.remote(OptimizerWorker)
+        self.optimizer = RemoteOptimizer.remote(config_path)
+        self.sync()
+
+    def sync(self):
+        self.local_optimizer = ray.get(self.optimizer.fetch.remote())
+
+    def __call__(self, *args, **kwargs):
+        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
+        return weights.copy(), grads.copy()
+
+    def __getattr__(self, name):
+        return getattr(self.local_optimizer, name)
+
+    def step_schedules(self):
+        self.optimizer.step_schedules.remote()
+        self.sync()
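OptimizerWorker is a Ray actor that holds the single authoritative optimizer (plus a copy of each parameter shard it has seen), while RayOptimizer is the per-process proxy that routes every weight update through that actor, so all workers share one optimizer state and one set of schedules. A minimal usage sketch, not part of the commit; the key format and config path are illustrative:

    import numpy
    import ray

    from spacy.cli.ray_utils import RayOptimizer

    ray.init()
    optimizer = RayOptimizer("config.cfg")  # hypothetical training config

    weights = numpy.zeros((8,), dtype="f")
    grads = numpy.ones((8,), dtype="f")
    # Each call ships the gradient to the actor, which applies the shared
    # optimizer state and returns the updated weights.
    weights, grads = optimizer(("tok2vec", 1), weights, grads)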

@@ -126,6 +126,7 @@ class ConfigSchema(BaseModel):
     raw_text=("Path to jsonl file with unlabelled text documents.", "option", "rt", Path),
     verbose=("Display more information for debugging purposes", "flag", "VV", bool),
     use_gpu=("Use GPU", "option", "g", int),
+    num_workers=("Parallel Workers", "option", "j", int),
     tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
     omit_extra_lookups=("Don't include extra lookups in model", "flag", "OEL", bool),
     # fmt: on
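A parallel run is then requested through this option. The exact invocation depends on how train_cli is registered with the CLI (not shown in this diff), but under plac's conventions the flag would surface as -j / --num-workers, e.g.:

    python -m spacy train ... -j 2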
@@ -139,6 +140,7 @@ def train_cli(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=1,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
@@ -181,22 +183,33 @@ def train_cli(
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()
 
+    train_args = dict(
+        config_path=config_path,
+        data_paths={"train": train_path, "dev": dev_path},
+        output_path=output_path,
+        raw_text=raw_text,
+        tag_map=tag_map,
+        weights_data=weights_data,
+        omit_extra_lookups=omit_extra_lookups,
+    )
+    if num_workers and num_workers > 1:
+        from spacy.cli.ray_utils import RayOptimizer
+        import ray
+        ray.init()
+        remote_train = ray.remote(setup_and_train)
+        train_args["remote_optimizer"] = RayOptimizer(config_path)
+        ray.get([remote_train.remote(use_gpu, train_args) for _ in range(num_workers)])
+    else:
+        setup_and_train(use_gpu, train_args)
+
+
+def setup_and_train(use_gpu, train_args):
     if use_gpu >= 0:
-        msg.info("Using GPU: {use_gpu}")
+        msg.info(f"Using GPU: {use_gpu}")
         util.use_gpu(use_gpu)
     else:
         msg.info("Using CPU")
-
-    train(
-        config_path,
-        {"train": train_path, "dev": dev_path},
-        output_path=output_path,
-        raw_text=raw_text,
-        tag_map=tag_map,
-        weights_data=weights_data,
-        omit_extra_lookups=omit_extra_lookups,
-    )
+    train(**train_args)
 
 
 def train(
     config_path,
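The num_workers branch above uses Ray's task API to fan the same setup_and_train call out across workers and blocks on ray.get until all of them return. A self-contained sketch of that pattern with a toy function standing in for setup_and_train:

    import ray

    ray.init()

    def square(i):
        # Any picklable function can be turned into a Ray remote task.
        return i * i

    remote_square = ray.remote(square)
    # Launch four tasks in parallel, then block until all finish.
    print(ray.get([remote_square.remote(i) for i in range(4)]))  # [0, 1, 4, 9]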
@@ -206,6 +219,7 @@ def train(
     tag_map=None,
     weights_data=None,
     omit_extra_lookups=False,
+    remote_optimizer=None
 ):
     msg.info(f"Loading config from: {config_path}")
     # Read the config first without creating objects, to get to the original nlp_config
@@ -220,6 +234,8 @@ def train(
     msg.info("Creating nlp from config")
     nlp = util.load_model_from_config(nlp_config)
     optimizer = training["optimizer"]
+    if remote_optimizer:
+        optimizer = remote_optimizer
     limit = training["limit"]
     msg.info("Loading training corpus")
     corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
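The swap works by duck typing: train() relies only on the optimizer's call signature and plain attribute reads (e.g. optimizer.learn_rate in the hunk below), and RayOptimizer satisfies both, forwarding __call__ to the actor and falling back to the fetched local copy via __getattr__. A tiny sketch of that implicit contract; the function name is illustrative:

    def run_update(optimizer, key, weights, gradient):
        # Anything exposing __call__ plus attribute reads fits here,
        # whether it is the local Thinc optimizer or the Ray proxy.
        weights, gradient = optimizer(key, weights, gradient, lr_scale=1.0)
        return weights, optimizer.learn_rate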
@@ -332,7 +348,6 @@ def train(
         eval_frequency=training["eval_frequency"],
         raw_text=raw_text,
     )
 
-
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)