commit ab50385986 (parent 610dfd85c2)
Author: Richard Liaw
Date:   2020-06-30 16:05:20 -07:00


@@ -132,7 +132,7 @@ def train_cli(
     verbose: bool = Opt(False, "--verbose", "-VV", help="Display more information for debugging purposes"),
     use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
     num_workers: int = Opt(None, "-j", help="Parallel Workers"),
-    strategy: str = Opt(None, "--strategy", help="Distributed training strategy (requires spacy_ray)"),
+    strategy: str = Opt("allreduce", "--strategy", help="Distributed training strategy (requires spacy_ray)"),
     ray_address: str = Opt(None, "--address", help="Address of the Ray cluster. Multi-node training (requires spacy_ray)"),
     tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
     omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
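This hunk changes the default distributed strategy from unset to "allreduce". Assuming `Opt` is spaCy's shorthand for `typer.Option` (as in spacy/cli/_util.py), the new default applies whenever `--strategy` is omitted. A minimal standalone sketch of that behavior in plain Typer follows; the command name and echo body are illustrative, not spaCy's code:

import typer

app = typer.Typer()

@app.command()
def train(
    num_workers: int = typer.Option(None, "-j", help="Parallel Workers"),
    strategy: str = typer.Option("allreduce", "--strategy", help="Distributed training strategy"),
    ray_address: str = typer.Option(None, "--address", help="Address of the Ray cluster"),
):
    # Running this with no flags prints "None allreduce None": the
    # strategy default kicks in without --strategy being passed.
    typer.echo(f"{num_workers} {strategy} {ray_address}")

if __name__ == "__main__":
    app()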
@@ -168,10 +168,7 @@ def train_cli(
     )
     if num_workers and num_workers > 1:
-        try:
-            from spacy_ray import distributed_setup_and_train
-        except ImportError:
-            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+        from spacy_ray import distributed_setup_and_train
         distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
     else:
         if use_gpu >= 0:
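With the import guard relocated into verify_cli_args (see the final hunk below), train_cli can import and call distributed_setup_and_train unconditionally on the multi-worker path. spacy_ray's internals are not part of this diff; the following is only a hypothetical sketch of what such an entry point could look like on Ray's task API, with the worker body stubbed out:

import ray

def distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args):
    # Join an existing cluster if an address is given, else start locally.
    ray.init(address=ray_address)

    @ray.remote(num_gpus=1 if use_gpu >= 0 else 0)
    def train_worker(rank):
        # Each worker would call spaCy's train() with a distinct
        # randomization_index so its shuffles differ (see the
        # create_train_batches hunk below). Stubbed out here.
        print(f"worker {rank}: strategy={strategy}, args={train_args}")

    # Launch one task per worker and block until all of them finish.
    ray.get([train_worker.remote(rank) for rank in range(num_workers)])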
@@ -190,7 +187,7 @@ def train(
     weights_data: Optional[bytes] = None,
     omit_extra_lookups: bool = False,
     disable_tqdm: bool = False,
-    remote_optimizer: Optimizer = None,
+    remote_optimizer = None,
     randomization_index: int = 0
 ) -> None:
     msg.info(f"Loading config from: {config_path}")
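Dropping the Optimizer annotation lets remote_optimizer be something other than a thinc Optimizer instance, e.g. a Ray actor handle that proxies a single shared optimizer and is merely called like one. A hypothetical sketch of that duck-typing; the class and method names are assumptions, not spacy_ray's API:

import ray

@ray.remote
class SharedOptimizer:
    """Central optimizer; workers hold only an actor handle to it."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, key, weights, gradient):
        # Apply the update centrally and return the result to the caller.
        return self.optimizer(key, weights, gradient)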
@@ -321,6 +318,8 @@ def create_train_batches(nlp, corpus, cfg, randomization_index):
     while True:
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
+        # This is used when doing parallel training to
+        # ensure that the dataset is shuffled differently across all workers.
         for _ in range(randomization_index):
             random.random()
         random.shuffle(train_examples)
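The comment added here documents a desynchronization trick: burning random.random() a different number of times per worker advances Python's global RNG to a different state on each worker, so the subsequent random.shuffle produces a different permutation even when all workers start from the same seed. A minimal standalone demo (the function name and seed are illustrative):

import random

def worker_shuffle(randomization_index, items, seed=0):
    random.seed(seed)  # every worker starts from the same seed
    for _ in range(randomization_index):
        random.random()  # burn one draw per index step to desync the RNG
    shuffled = list(items)
    random.shuffle(shuffled)
    return shuffled

# Same seed, different randomization_index -> different permutations:
print(worker_shuffle(0, range(5)))
print(worker_shuffle(1, range(5)))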
@@ -564,6 +563,9 @@ def verify_cli_args(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=None,
+    strategy=None,
+    ray_address=None,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
@@ -596,6 +598,12 @@ def verify_cli_args(
     if init_tok2vec is not None and not init_tok2vec.exists():
         msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
+    if num_workers and num_workers > 1:
+        try:
+            from spacy_ray import distributed_setup_and_train
+        except ImportError:
+            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+
 def verify_textcat_config(nlp, nlp_config):
     # if 'positive_label' is provided: double check whether it's in the data and
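Moving the spacy_ray import check into verify_cli_args means a missing optional dependency is reported with a friendly message during argument validation, before any config loading or data I/O, which in turn lets the unguarded import in train_cli above stay simple: by the time it runs, verify_cli_args has already confirmed spacy_ray is installed.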