Mirror of https://github.com/explosion/spaCy.git

Commit fdc9242bc1 ("small-changes")
Parent: a5a3ed722c
@@ -5,11 +5,13 @@ from wasabi import msg
 from .. import util
 
 class OptimizerWorker:
-    def __init__(self, config_path):
+    def __init__(self, config_path, world_size, sync=True):
         self.optimizer = _create_optimizer(config_path)
         self.weights_dict = {}
+        self.world_size = world_size
+        self.sync = sync
 
-    def call(self, key, weights, gradient, *, lr_scale=1.0):
+    def call(self, rank, key, weights, gradient, *, lr_scale=1.0):
         if key not in self.weights_dict:
             self.weights_dict[key] = weights.copy()
         new_weights, new_grads = self.optimizer(
@@ -26,18 +28,19 @@ class OptimizerWorker:
 class RayOptimizer:
     local_optimizer = None
 
-    def __init__(self, config_path, use_gpu):
+    def __init__(self, config_path, use_gpu, rank):
         RemoteOptimizer = ray.remote(OptimizerWorker)
         if use_gpu >= 0:
             RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
         self.optimizer = RemoteOptimizer.remote(config_path)
+        self.rank = rank
         self.sync()
 
     def sync(self):
         self.local_optimizer = ray.get(self.optimizer.fetch.remote())
 
     def __call__(self, *args, **kwargs):
-        weights, grads = ray.get(self.optimizer.call.remote(*args, **kwargs))
+        weights, grads = ray.get(self.optimizer.call.remote(self.rank, *args, **kwargs))
         return weights.copy(), grads.copy()
 
     def __getattr__(self, name):
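
A minimal, self-contained sketch of the pattern this hunk extends, not the spaCy code itself: one Ray actor holds the shared optimizer state, and each training process keeps a thin proxy that remembers its rank and forwards it on every call. The names ParamServer and WorkerProxy below are illustrative stand-ins for OptimizerWorker and RayOptimizer.

import ray

@ray.remote
class ParamServer:
    # Toy stand-in for OptimizerWorker: holds one shared weight vector.
    def __init__(self, world_size):
        self.world_size = world_size
        self.weights = None

    def call(self, rank, weights, gradient):
        # rank identifies the calling worker, as in the updated call() signature.
        if self.weights is None:
            self.weights = list(weights)
        # Plain SGD step; the real worker delegates to its optimizer instead.
        self.weights = [w - 0.1 * g for w, g in zip(self.weights, gradient)]
        return self.weights

class WorkerProxy:
    # Mirrors the RayOptimizer change: remember the rank, pass it on each call.
    def __init__(self, server, rank):
        self.server = server
        self.rank = rank

    def __call__(self, weights, gradient):
        return ray.get(self.server.call.remote(self.rank, weights, gradient))

if __name__ == "__main__":
    ray.init()
    server = ParamServer.remote(world_size=2)
    workers = [WorkerProxy(server, rank) for rank in range(2)]
    print(workers[0]([1.0, 2.0], [0.5, 0.5]))  # [0.95, 1.95]
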
@@ -69,7 +69,7 @@ class AllreduceOptimizer:
         weights = self.allreduce(weights) / self.communicator.size()
 
-        gradient = self.allreduce(gradient)
+        gradient = self.allreduce(gradient) / self.communicator.size()
         flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
         return flat_weights, gradient
 
 
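
The one-line change above divides the allreduced gradient by communicator.size(), so workers end up averaging their gradients rather than summing them, matching what is already done for the weights. A toy illustration with plain Python lists standing in for the NCCL allreduce:

def allreduce(per_worker_grads):
    # Element-wise sum across workers, like an allreduce with a sum op.
    return [sum(vals) for vals in zip(*per_worker_grads)]

grads = [[0.25, 0.5], [0.75, 1.5]]           # gradients from two workers
summed = allreduce(grads)                    # [1.0, 2.0]  (old behaviour)
averaged = [g / len(grads) for g in summed]  # [0.5, 1.0]  (what the new line computes)
print(summed, averaged)
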
@@ -198,7 +198,7 @@ def train_cli(
 
     if num_workers and num_workers > 1:
         import ray
-        ray.init()
+        ray.init(address="auto")
         if strategy == "ps":
             from spacy.cli.ray_param_server import RayOptimizer
         remote_train = ray.remote(setup_and_train)
@@ -401,8 +401,12 @@ def train(
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
     print_row = setup_printer(training, nlp)
 
+    tqdm_args = dict(total=training["eval_frequency"], leave=False)
+    global world_rank
+    if world_rank is not None:
+        tqdm_args["disable"] = bool(world_rank != 0)
     try:
-        progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
+        progress = tqdm.tqdm(**tqdm_args)
         for batch, info, is_best_checkpoint in training_step_iterator:
             progress.update(1)
             if is_best_checkpoint is not None:
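
The progress-bar change builds the tqdm keyword arguments once and sets disable on every rank except 0, so only one worker prints a bar in distributed runs. A small stand-alone sketch of the same idea; make_progress is an illustrative helper, not part of spaCy:

import tqdm

def make_progress(eval_frequency, world_rank=None):
    tqdm_args = dict(total=eval_frequency, leave=False)
    if world_rank is not None:
        # Only rank 0 shows a progress bar; all other ranks stay silent.
        tqdm_args["disable"] = bool(world_rank != 0)
    return tqdm.tqdm(**tqdm_args)

bar = make_progress(200, world_rank=1)  # disabled, prints nothing
for _ in range(200):
    bar.update(1)
bar.close()
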
@@ -411,7 +415,7 @@ def train(
                 if is_best_checkpoint and output_path is not None:
                     update_meta(training, nlp, info)
                     nlp.to_disk(output_path / "model-best")
-                progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
+                progress = tqdm.tqdm(**tqdm_args)
             # Clean up the objects to faciliate garbage collection.
             for eg in batch:
                 eg.doc = None
@@ -437,6 +441,10 @@ def train(
 
 def create_train_batches(nlp, corpus, cfg):
     epochs_todo = cfg.get("max_epochs", 0)
+    if world_rank is not None:
+        for i in range(world_rank):
+            # Increment random seed
+            random.random()
     while True:
         train_examples = list(
             corpus.train_dataset(
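
The added loop draws world_rank throwaway numbers from Python's global RNG before batching starts, presumably so that each rank advances the RNG differently and shuffles the corpus into its own order. A hedged sketch of that effect; shuffled_for_rank is an illustrative helper:

import random

def shuffled_for_rank(items, world_rank, seed=0):
    random.seed(seed)
    for _ in range(world_rank):
        random.random()  # advance the RNG differently per rank
    items = list(items)
    random.shuffle(items)
    return items

print(shuffled_for_rank(range(5), world_rank=0))
print(shuffled_for_rank(range(5), world_rank=1))  # typically a different order than rank 0
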
@@ -452,16 +460,18 @@ def create_train_batches(nlp, corpus, cfg):
             raise ValueError(Errors.E988)
         random.shuffle(train_examples)
 
-        if world_size is not None:
-            # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
-            num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
-            total_size = num_samples * world_size  # expected to overflow
-            train_examples += train_examples[:(total_size - len(train_examples))]
-            assert len(train_examples) == total_size
+        # # TODO: with large batches, this can be bad.
+        # if world_size is not None:
+        #     # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+        #     num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
+        #     total_size = num_samples * world_size  # expected to overflow
+        #     train_examples += train_examples[:(total_size - len(train_examples))]
+        #     assert len(train_examples) == total_size
 
-            # subsample
-            train_examples = train_examples[world_rank:total_size:world_size]
-            assert len(train_examples) == num_samples
+        # # subsample
+        # train_examples = train_examples[world_rank:total_size:world_size]
+        # assert len(train_examples) == num_samples
+        # print(f"Reset epoch: Only using {num_samples} out of {total_size} samples")
 
         batches = util.minibatch_by_words(
             train_examples,
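
For reference, the block commented out above implemented the same pad-then-stride sharding as torch.utils.data.distributed.DistributedSampler: pad the example list so it divides evenly by world_size, then give each rank every world_size-th item. A toy version showing what it computes; shard is an illustrative helper:

import math

def shard(examples, world_rank, world_size):
    num_samples = int(math.ceil(len(examples) * 1.0 / world_size))
    total_size = num_samples * world_size
    examples = examples + examples[: total_size - len(examples)]  # pad by wrapping around
    assert len(examples) == total_size
    subset = examples[world_rank:total_size:world_size]           # stride by rank
    assert len(subset) == num_samples
    return subset

print(shard(list(range(10)), world_rank=0, world_size=3))  # [0, 3, 6, 9]
print(shard(list(range(10)), world_rank=1, world_size=3))  # [1, 4, 7, 0]
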
@@ -474,7 +484,7 @@ def create_train_batches(nlp, corpus, cfg):
             yield first
         except StopIteration:
             raise ValueError(Errors.E986)
-        for batch in batches:
+        for i, batch in enumerate(batches):
             yield batch
         epochs_todo -= 1
         # We intentionally compare exactly to 0 here, so that max_epochs < 1