distributed

Richard Liaw 2020-06-15 01:30:25 -07:00
parent 82a3a3b9a7
commit 54951aa976


@@ -1,6 +1,6 @@
 from typing import Optional, Dict, List, Union, Sequence
 from timeit import default_timer as timer
+import math
 import srsly
 from pydantic import BaseModel, FilePath
 import plac
@@ -199,16 +199,28 @@ def train_cli(
         ray.init()
         remote_train = ray.remote(setup_and_train)
         train_args["remote_optimizer"] = RayOptimizer(config_path)
-        ray.get([remote_train.remote(use_gpu, train_args) for _ in range(num_workers)])
+        ray.get([remote_train.remote(
+            use_gpu,
+            train_args,
+            rank=rank,
+            total_workers=num_workers) for rank in range(num_workers)])
     else:
         setup_and_train(use_gpu, train_args)
 
 
-def setup_and_train(use_gpu, train_args):
+world_rank = None
+world_size = None
+
+
+def setup_and_train(use_gpu, train_args, rank, total_workers):
     if use_gpu >= 0:
         msg.info("Using GPU: {use_gpu}")
         util.use_gpu(use_gpu)
     else:
         msg.info("Using CPU")
+    if rank:
+        global world_rank
+        world_rank = rank
+        global world_size
+        world_size = total_workers
     train(**train_args)
 
 
 def train(
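For readers skimming the patch: ray.remote(setup_and_train) turns the function into a Ray task, each .remote(...) call schedules one worker, and ray.get blocks until every task returns. A minimal standalone sketch of the same fan-out pattern, with a made-up work function standing in for setup_and_train:

import ray

ray.init()

def work(use_gpu, train_args, rank, total_workers):
    # Each task receives its rank in [0, total_workers) plus the shared args.
    return (rank, total_workers)

remote_work = ray.remote(work)
num_workers = 4
# Schedule one task per rank, then block until all of them have finished.
results = ray.get([
    remote_work.remote(-1, {}, rank=rank, total_workers=num_workers)
    for rank in range(num_workers)
])

Note that in the patch itself, "if rank:" is falsy for rank 0, so the rank-0 worker leaves world_rank and world_size at None and will not shard its data in create_train_batches below.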
@@ -226,7 +238,7 @@ def train(
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])
     if config["training"].get("use_pytorch_for_gpu_memory"):
         # It feels kind of weird to not have a default for this.
         use_pytorch_for_gpu_memory()
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
@@ -401,6 +413,18 @@ def create_train_batches(nlp, corpus, cfg):
     if len(train_examples) == 0:
         raise ValueError(Errors.E988)
     random.shuffle(train_examples)
+
+    if world_size is not None:
+        # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+        num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
+        total_size = num_samples * world_size  # expected to overflow
+        train_examples += train_examples[:(total_size - len(train_examples))]
+        assert len(train_examples) == total_size
+        # subsample
+        train_examples = train_examples[world_rank:total_size:world_size]
+        assert len(train_examples) == num_samples
+
     batches = util.minibatch_by_words(
         train_examples,
         size=cfg["batch_size"],
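The appended block is the same padding-and-striding scheme as PyTorch's DistributedSampler, which the comment links to: pad the shuffled list so its length divides evenly by world_size, then give the worker with rank r every world_size-th example starting at index r. A standalone worked example with made-up sizes:

import math

examples = list(range(10))          # pretend these are 10 training examples
world_size = 4
num_samples = int(math.ceil(len(examples) * 1.0 / world_size))  # 3 per worker
total_size = num_samples * world_size                           # 12 after padding
examples += examples[:(total_size - len(examples))]             # recycle the head to pad
assert len(examples) == total_size
# Worker r keeps every world_size-th example, starting at its own rank.
shards = [examples[rank:total_size:world_size] for rank in range(world_size)]
assert shards == [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]

Every worker ends up with exactly num_samples examples, so the workers stay in lockstep, at the cost of a few duplicated examples whenever the corpus size is not a multiple of world_size.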