From ab50385986edb1850ae324431c363dd40ed5cd12 Mon Sep 17 00:00:00 2001
From: Richard Liaw
Date: Tue, 30 Jun 2020 16:05:20 -0700
Subject: [PATCH] port

---
 spacy/cli/train.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index f495cd4c8..9cfd1e4c6 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -132,7 +132,7 @@ def train_cli(
     verbose: bool = Opt(False, "--verbose", "-VV", help="Display more information for debugging purposes"),
     use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
     num_workers: int = Opt(None, "-j", help="Parallel Workers"),
-    strategy: str = Opt(None, "--strategy", help="Distributed training strategy (requires spacy_ray)"),
+    strategy: str = Opt("allreduce", "--strategy", help="Distributed training strategy (requires spacy_ray)"),
     ray_address: str = Opt(None, "--address", help="Address of the Ray cluster. Multi-node training (requires spacy_ray)"),
     tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
     omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
@@ -168,10 +168,7 @@ def train_cli(
     )

     if num_workers and num_workers > 1:
-        try:
-            from spacy_ray import distributed_setup_and_train
-        except ImportError:
-            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+        from spacy_ray import distributed_setup_and_train
         distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
     else:
         if use_gpu >= 0:
@@ -190,7 +187,7 @@ def train(
     weights_data: Optional[bytes] = None,
     omit_extra_lookups: bool = False,
     disable_tqdm: bool = False,
-    remote_optimizer: Optimizer = None,
+    remote_optimizer = None,
     randomization_index: int = 0
 ) -> None:
     msg.info(f"Loading config from: {config_path}")
@@ -321,6 +318,8 @@ def create_train_batches(nlp, corpus, cfg, randomization_index):
     while True:
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
+        # This is used when doing parallel training to
+        # ensure that the dataset is shuffled differently across all workers.
         for _ in range(randomization_index):
             random.random()
         random.shuffle(train_examples)
@@ -564,6 +563,9 @@ def verify_cli_args(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=None,
+    strategy=None,
+    ray_address=None,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
@@ -596,6 +598,12 @@ def verify_cli_args(
     if init_tok2vec is not None and not init_tok2vec.exists():
         msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)

+    if num_workers and num_workers > 1:
+        try:
+            from spacy_ray import distributed_setup_and_train
+        except ImportError:
+            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+

 def verify_textcat_config(nlp, nlp_config):
     # if 'positive_label' is provided: double check whether it's in the data and