Mirror of https://github.com/explosion/spaCy.git
Commit 54951aa976 ("distributed"), parent 82a3a3b9a7
@@ -1,6 +1,6 @@
 from typing import Optional, Dict, List, Union, Sequence
 from timeit import default_timer as timer
-
+import math
 import srsly
 from pydantic import BaseModel, FilePath
 import plac
@@ -199,16 +199,28 @@ def train_cli(
         ray.init()
         remote_train = ray.remote(setup_and_train)
         train_args["remote_optimizer"] = RayOptimizer(config_path)
-        ray.get([remote_train.remote(use_gpu, train_args) for _ in range(num_workers)])
+        ray.get([remote_train.remote(
+            use_gpu,
+            train_args,
+            rank=rank,
+            total_workers=num_workers) for rank in range(num_workers)])
     else:
         setup_and_train(use_gpu, train_args)
 
-def setup_and_train(use_gpu, train_args):
+world_rank = None
+world_size = None
+
+def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
     if use_gpu >= 0:
         msg.info(f"Using GPU: {use_gpu}")
         util.use_gpu(use_gpu)
     else:
         msg.info("Using CPU")
+    if rank is not None:
+        global world_rank
+        world_rank = rank
+        global world_size
+        world_size = total_workers
     train(**train_args)
 
 def train(
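For context, the launch path above uses Ray's task API: ray.remote() wraps an ordinary function so that each .remote(...) call is scheduled as an independent worker task, and ray.get() blocks until every returned future has resolved. Below is a minimal, self-contained sketch of that fan-out pattern; the work function is a placeholder, not spaCy's setup_and_train.

import ray

def work(rank, total_workers):
    # Stand-in for per-worker training; just reports which shard it owns.
    return f"worker {rank}/{total_workers} finished"

if __name__ == "__main__":
    ray.init()  # start a local Ray runtime (or connect to an existing cluster)
    remote_work = ray.remote(work)
    num_workers = 2
    # One remote task per worker; ray.get waits for all of them.
    results = ray.get([
        remote_work.remote(rank=rank, total_workers=num_workers)
        for rank in range(num_workers)
    ])
    print(results)

In the patch itself, the rank and worker count are stashed in the module-level world_rank and world_size globals so that create_train_batches can later shard the data without changing its signature.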
@@ -226,7 +238,7 @@ def train(
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])
     if config["training"].get("use_pytorch_for_gpu_memory"):
-        # It feels kind of weird to not have a default for this.
+        # It feels kind of weird to not have a default for this.
         use_pytorch_for_gpu_memory()
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
@@ -401,6 +413,18 @@ def create_train_batches(nlp, corpus, cfg):
     if len(train_examples) == 0:
         raise ValueError(Errors.E988)
     random.shuffle(train_examples)
+
+    if world_size is not None:
+        # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+        num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
+        total_size = num_samples * world_size  # expected to overflow
+        train_examples += train_examples[:(total_size - len(train_examples))]
+        assert len(train_examples) == total_size
+
+        # subsample
+        train_examples = train_examples[world_rank:total_size:world_size]
+        assert len(train_examples) == num_samples
+
     batches = util.minibatch_by_words(
         train_examples,
         size=cfg["batch_size"],
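The new block in create_train_batches follows the same pad-then-stride scheme as PyTorch's DistributedSampler (linked in the comment): the shuffled examples are padded by wrapping around so the list divides evenly by world_size, and each rank then takes every world_size-th item starting at its own offset. A standalone sketch of that arithmetic on toy data follows; the shard helper and the numbers are illustrative only.

import math

def shard(examples, world_rank, world_size):
    # Pad by wrapping around so every worker gets the same number of items.
    num_samples = int(math.ceil(len(examples) / world_size))
    total_size = num_samples * world_size
    examples = examples + examples[: total_size - len(examples)]
    assert len(examples) == total_size
    # Each rank takes a strided slice: rank, rank + world_size, rank + 2 * world_size, ...
    return examples[world_rank:total_size:world_size]

examples = list(range(10))  # 10 toy examples split across 3 workers
for rank in range(3):
    print(rank, shard(examples, rank, 3))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]

Every worker ends up with the same number of items (here num_samples = 4), at the cost of duplicating a few examples; that is why the diff pads rather than truncating.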