From 54951aa976fcae7fb2b9b0df276c216fdc944fa6 Mon Sep 17 00:00:00 2001
From: Richard Liaw
Date: Mon, 15 Jun 2020 01:30:25 -0700
Subject: [PATCH] distributed

---
 spacy/cli/train_from_config.py | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/spacy/cli/train_from_config.py b/spacy/cli/train_from_config.py
index 597bb3058..4bc7dc33e 100644
--- a/spacy/cli/train_from_config.py
+++ b/spacy/cli/train_from_config.py
@@ -1,6 +1,6 @@
 from typing import Optional, Dict, List, Union, Sequence
 from timeit import default_timer as timer
-
+import math
 import srsly
 from pydantic import BaseModel, FilePath
 import plac
@@ -199,16 +199,28 @@ def train_cli(
         ray.init()
         remote_train = ray.remote(setup_and_train)
         train_args["remote_optimizer"] = RayOptimizer(config_path)
-        ray.get([remote_train.remote(use_gpu, train_args) for _ in range(num_workers)])
+        ray.get([remote_train.remote(
+            use_gpu,
+            train_args,
+            rank=rank,
+            total_workers=num_workers) for rank in range(num_workers)])
     else:
         setup_and_train(use_gpu, train_args)
 
-def setup_and_train(use_gpu, train_args):
+world_rank = None
+world_size = None
+
+def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
     if use_gpu >= 0:
         msg.info("Using GPU: {use_gpu}")
         util.use_gpu(use_gpu)
     else:
         msg.info("Using CPU")
+    if rank is not None:
+        global world_rank
+        world_rank = rank
+        global world_size
+        world_size = total_workers
     train(**train_args)
 
 def train(
@@ -226,7 +238,7 @@ def train(
     config = util.load_config(config_path, create_objects=False)
     util.fix_random_seed(config["training"]["seed"])
     if config["training"].get("use_pytorch_for_gpu_memory"):
-        # It feels kind of weird to not have a default for this. 
+        # It feels kind of weird to not have a default for this.
         use_pytorch_for_gpu_memory()
     nlp_config = config["nlp"]
     config = util.load_config(config_path, create_objects=True)
@@ -401,6 +413,18 @@ def create_train_batches(nlp, corpus, cfg):
     if len(train_examples) == 0:
         raise ValueError(Errors.E988)
     random.shuffle(train_examples)
+
+    if world_size is not None:
+        # Taken from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+        num_samples = int(math.ceil(len(train_examples) * 1.0 / world_size))
+        total_size = num_samples * world_size  # may exceed len(train_examples)
+        train_examples += train_examples[:(total_size - len(train_examples))]
+        assert len(train_examples) == total_size
+
+        # take every world_size-th example, offset by this worker's rank
+        train_examples = train_examples[world_rank:total_size:world_size]
+        assert len(train_examples) == num_samples
+
     batches = util.minibatch_by_words(
         train_examples,
         size=cfg["batch_size"],
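
Note: the fan-out in train_cli() uses the standard Ray pattern of wrapping a
plain function with ray.remote() and gathering one task per rank. A minimal
standalone sketch of that pattern, assuming only that Ray is installed
(worker_fn and num_workers below are illustrative placeholders, not spaCy
code):

    import ray

    def worker_fn(rank, total_workers):
        # Each task receives its rank up front so it can shard data later.
        return f"worker {rank}/{total_workers} done"

    if __name__ == "__main__":
        ray.init()
        remote_fn = ray.remote(worker_fn)
        num_workers = 3
        print(ray.get([
            remote_fn.remote(rank, num_workers) for rank in range(num_workers)
        ]))

Blocking on a single ray.get() over the whole list keeps the driver alive
until every worker finishes, which is why train_cli() wraps the entire list
comprehension in one call.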
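
The sharding added to create_train_batches() is the pad-then-stride scheme
from PyTorch's DistributedSampler (the file the patch comment cites): wrap
the shuffled list around so it divides evenly, then give each rank every
world_size-th example. A self-contained sketch of just that logic
(demo_shard is a hypothetical helper, not part of the patch):

    import math

    def demo_shard(examples, rank, world_size):
        # Pad by wrapping around so the list divides evenly across workers.
        num_samples = math.ceil(len(examples) / world_size)
        total_size = num_samples * world_size
        examples = examples + examples[: total_size - len(examples)]
        # Stride through the padded list, starting at this worker's rank.
        shard = examples[rank:total_size:world_size]
        assert len(shard) == num_samples
        return shard

    data = list(range(10))
    print([demo_shard(data, r, 3) for r in range(3)])
    # [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]

The shards partition the padded list and are equal-sized, so every worker
sees the same number of batches per epoch; examples 0 and 1 appear twice
only because of the padding.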