small-changes

Richard Liaw 2020-06-17 19:42:53 -07:00
parent a5a3ed722c
commit fdc9242bc1
3 changed files with 31 additions and 18 deletions

View File

@@ -5,11 +5,13 @@ from wasabi import msg
 from .. import util
 class OptimizerWorker:
-    def __init__(self, config_path):
+    def __init__(self, config_path, world_size, sync=True):
         self.optimizer = _create_optimizer(config_path)
         self.weights_dict = {}
+        self.world_size = world_size
+        self.sync = sync
-    def call(self, key, weights, gradient, *, lr_scale=1.0):
+    def call(self, rank, key, weights, gradient, *, lr_scale=1.0):
         if key not in self.weights_dict:
             self.weights_dict[key] = weights.copy()
         new_weights, new_grads = self.optimizer(
@@ -26,18 +28,19 @@ class OptimizerWorker:
 class RayOptimizer:
     local_optimizer = None
-    def __init__(self, config_path, use_gpu):
+    def __init__(self, config_path, use_gpu, rank):
         RemoteOptimizer = ray.remote(OptimizerWorker)
         if use_gpu >= 0:
             RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
         self.optimizer = RemoteOptimizer.remote(config_path)
+        self.rank = rank
         self.sync()
     def sync(self):
         self.local_optimizer = ray.get(self.optimizer.fetch.remote())
     def __call__(self, *args, **kwargs):
-        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
+        weights, grads = ray.get(self.optimizer.call.remote(self.rank, *args, **kwargs))
         return weights.copy(), grads.copy()
     def __getattr__(self, name):
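
A rough usage sketch for the classes above: with this change a training worker
passes its rank when constructing RayOptimizer, and every optimizer call is
forwarded to the shared actor tagged with that rank. The config path and the
concrete values below are illustrative placeholders; only the class, its new
signature and the import path (which shows up in the train_cli hunk further
down) come from this diff, and a valid spaCy training config is assumed to
exist.

    import ray
    from spacy.cli.ray_param_server import RayOptimizer

    ray.init(address="auto")      # join the already running cluster

    rank = 0                      # hypothetical id of this training worker
    optimizer = RayOptimizer("config.cfg", use_gpu=-1, rank=rank)

    # Per __call__ above, the rank is injected automatically, so callers keep
    # the old (key, weights, gradient, lr_scale=...) signature:
    # weights, grads = optimizer(key, weights, gradient, lr_scale=1.0)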

View File

@@ -69,7 +69,7 @@ class AllreduceOptimizer:
         weights = self.allreduce(weights) / self.communicator.size()
-        gradient = self.allreduce(gradient)
+        gradient = self.allreduce(gradient) / self.communicator.size()
         flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
         return flat_weights, gradient
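
The fix here: weights were already averaged after the allreduce, but gradients
were only summed, so their magnitude grew with the number of workers. A minimal
numpy sketch of the arithmetic, with allreduce and world_size standing in for
the real communicator (no Ray or NCCL involved):

    import numpy as np

    def allreduce(per_worker_values):
        # The real collective sums the same array across workers; here we just
        # sum a list of per-worker arrays to show the arithmetic.
        return np.sum(per_worker_values, axis=0)

    world_size = 4
    local_grads = [np.full(3, 0.5 * (r + 1)) for r in range(world_size)]

    summed = allreduce(local_grads)     # what the old code fed the optimizer
    averaged = summed / world_size      # what the change feeds it instead
    print(summed, averaged)             # [5. 5. 5.] [1.25 1.25 1.25]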

View File

@@ -198,7 +198,7 @@ def train_cli(
     if num_workers and num_workers > 1:
         import ray
-        ray.init()
+        ray.init(address="auto")
         if strategy == "ps":
             from spacy.cli.ray_param_server import RayOptimizer
         remote_train = ray.remote(setup_and_train)
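
The ray.init change swaps behaviour rather than just an argument: ray.init()
spins up a private single-node Ray instance inside the training process, while
ray.init(address="auto") attaches to a cluster that is already running on the
machine (started beforehand, e.g. with `ray start --head`, and possibly
spanning more nodes). A small sketch, assuming such a cluster exists:

    import ray

    # ray.init()                  # old: throwaway in-process Ray instance
    ray.init(address="auto")      # new: join the existing cluster

    # Sanity check that the workers will see the whole cluster's resources.
    print(ray.cluster_resources())
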
@@ -401,8 +401,12 @@ def train(
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)
+    tqdm_args = dict(total=training["eval_frequency"], leave=False)
+    global world_rank
+    if world_rank is not None:
+        tqdm_args["disable"] = bool(world_rank != 0)
     try:
-        progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
+        progress = tqdm.tqdm(**tqdm_args)
         for batch, info, is_best_checkpoint in training_step_iterator:
             progress.update(1)
             if is_best_checkpoint is not None:
@@ -411,7 +415,7 @@ def train(
                 if is_best_checkpoint and output_path is not None:
                     update_meta(training, nlp, info)
                     nlp.to_disk(output_path / "model-best")
-                progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
+                progress = tqdm.tqdm(**tqdm_args)
             # Clean up the objects to faciliate garbage collection.
             for eg in batch:
                 eg.doc = None
@@ -437,6 +441,10 @@ def train(
 def create_train_batches(nlp, corpus, cfg):
     epochs_todo = cfg.get("max_epochs", 0)
+    if world_rank is not None:
+        for i in range(world_rank):
+            # Increment random seed
+            random.random()
     while True:
         train_examples = list(
             corpus.train_dataset(
@@ -452,16 +460,18 @@ def create_train_batches(nlp, corpus, cfg):
             raise ValueError(Errors.E988)
         random.shuffle(train_examples)
-        if world_size is not None:
-            # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
-            num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
-            total_size = num_samples * world_size  # expected to overflow
-            train_examples += train_examples[:(total_size - len(train_examples))]
-            assert len(train_examples) == total_size
+        # # TODO: with large batches, this can be bad.
+        # if world_size is not None:
+        #     # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+        #     num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
+        #     total_size = num_samples * world_size  # expected to overflow
+        #     train_examples += train_examples[:(total_size - len(train_examples))]
+        #     assert len(train_examples) == total_size
-            # subsample
-            train_examples = train_examples[world_rank:total_size:world_size]
-            assert len(train_examples) == num_samples
+        #     # subsample
+        #     train_examples = train_examples[world_rank:total_size:world_size]
+        #     assert len(train_examples) == num_samples
+        #     print(f"Reset epoch: Only using {num_samples} out of {total_size} samples")
         batches = util.minibatch_by_words(
             train_examples,
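
The commented-out block is the pad-and-stride sharding borrowed from PyTorch's
DistributedSampler; with it disabled, every worker iterates over the full
shuffled dataset, and the per-rank random.random() calls added above are what
give each rank a different shuffle order. For reference, the disabled sharding
behaves like this standalone sketch (plain ints in place of Example objects):

    import math

    def shard(examples, world_rank, world_size):
        # Pad by wrapping around so every rank gets the same number of samples,
        # then take every world_size-th item starting at this rank's offset.
        num_samples = int(math.ceil(len(examples) / world_size))
        total_size = num_samples * world_size
        examples = examples + examples[: total_size - len(examples)]
        return examples[world_rank:total_size:world_size]

    print(shard(list(range(10)), world_rank=0, world_size=4))  # [0, 4, 8]
    print(shard(list(range(10)), world_rank=1, world_size=4))  # [1, 5, 9]
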
@@ -474,7 +484,7 @@ def create_train_batches(nlp, corpus, cfg):
             yield first
         except StopIteration:
             raise ValueError(Errors.E986)
-        for batch in batches:
+        for i, batch in enumerate(batches):
             yield batch
         epochs_todo -= 1
         # We intentionally compare exactly to 0 here, so that max_epochs < 1