diff --git a/spacy/cli/train_from_config.py b/spacy/cli/train_from_config.py index dc8e8fba2..8f4c25cd4 100644 --- a/spacy/cli/train_from_config.py +++ b/spacy/cli/train_from_config.py @@ -128,6 +128,9 @@ class ConfigSchema(BaseModel): use_gpu=("Use GPU", "option", "g", int), num_workers=("Parallel Workers", "option", "j", int), strategy=("Distributed training strategy (requires spacy_ray)", "option", "strat", str), + ray_address=( + "Address of Ray cluster. Only required for multi-node training (requires spacy_ray)", + "option", "address", str), tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path), omit_extra_lookups=("Don't include extra lookups in model", "flag", "OEL", bool), # fmt: on @@ -143,6 +146,7 @@ def train_cli( use_gpu=-1, num_workers=1, strategy="allreduce", + ray_address=None, tag_map_path=None, omit_extra_lookups=False, ): @@ -200,7 +204,7 @@ def train_cli( from spacy_ray import distributed_setup_and_train except ImportError: msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1) - distributed_setup_and_train(use_gpu, num_workers, strategy, train_args) + distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args) else: if use_gpu >= 0: msg.info(f"Using GPU: {use_gpu}")