commit ab50385986
parent 610dfd85c2

    port
@@ -132,7 +132,7 @@ def train_cli(
     verbose: bool = Opt(False, "--verbose", "-VV", help="Display more information for debugging purposes"),
     use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
     num_workers: int = Opt(None, "-j", help="Parallel Workers"),
-    strategy: str = Opt(None, "--strategy", help="Distributed training strategy (requires spacy_ray)"),
+    strategy: str = Opt("allreduce", "--strategy", help="Distributed training strategy (requires spacy_ray)"),
     ray_address: str = Opt(None, "--address", help="Address of the Ray cluster. Multi-node training (requires spacy_ray)"),
     tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
     omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
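Note: the only change in the hunk above is the default of --strategy, from None to "allreduce", so multi-worker runs pick a gradient-combination strategy even when the flag is omitted. Purely for illustration (the flag spellings come from the Opt(...) declarations above; the positional arguments are elided and the address value is a placeholder), a distributed run would be invoked roughly like:

    python -m spacy train ... -j 4 --strategy allreduce --address <ray-head-address>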
@@ -168,10 +168,7 @@ def train_cli(
     )

     if num_workers and num_workers > 1:
-        try:
-            from spacy_ray import distributed_setup_and_train
-        except ImportError:
-            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+        from spacy_ray import distributed_setup_and_train
         distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
     else:
         if use_gpu >= 0:
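Note: with the try/except moved out (see the verify_cli_args hunk further down), train_cli can assume spacy_ray is importable by the time it gets here, so the dispatch reduces to a plain branch. A minimal runnable sketch of that shape, where run_local_training is a hypothetical stand-in for the existing single-process path and the distributed_setup_and_train call signature is taken from the diff:

    def run_local_training(use_gpu, train_args):
        # Hypothetical stand-in for the existing single-process training path.
        print(f"local training (gpu={use_gpu}) with {train_args}")

    def dispatch_training(use_gpu, num_workers, strategy, ray_address, train_args):
        # Multi-worker runs hand off to spacy_ray; everything else stays local.
        if num_workers and num_workers > 1:
            from spacy_ray import distributed_setup_and_train
            distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
        else:
            run_local_training(use_gpu, train_args)

    dispatch_training(use_gpu=-1, num_workers=1, strategy="allreduce", ray_address=None, train_args={})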
@@ -190,7 +187,7 @@ def train(
     weights_data: Optional[bytes] = None,
     omit_extra_lookups: bool = False,
     disable_tqdm: bool = False,
-    remote_optimizer: Optimizer = None,
+    remote_optimizer = None,
     randomization_index: int = 0
 ) -> None:
     msg.info(f"Loading config from: {config_path}")
@@ -321,6 +318,8 @@ def create_train_batches(nlp, corpus, cfg, randomization_index):
     while True:
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
+        # This is used when doing parallel training to
+        # ensure that the dataset is shuffled differently across all workers.
         for _ in range(randomization_index):
             random.random()
         random.shuffle(train_examples)
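Note: the two added comment lines document the trick here: every worker is assumed to seed `random` identically, then burns `randomization_index` draws before shuffling, so each worker visits the same examples in a different order. A standalone sketch of that mechanism (illustrative only, not spaCy code; the shared seed is an assumption):

    import random

    def shuffled_for_worker(items, randomization_index, seed=0):
        random.seed(seed)                     # every worker starts from the same seed
        for _ in range(randomization_index):  # advance the RNG a worker-specific number of steps
            random.random()
        items = list(items)
        random.shuffle(items)                 # each worker now shuffles in a different order
        return items

    print(shuffled_for_worker(range(10), 0))  # worker 0's order
    print(shuffled_for_worker(range(10), 1))  # worker 1's order: same data, different order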
@@ -564,6 +563,9 @@ def verify_cli_args(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=None,
+    strategy=None,
+    ray_address=None,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
@@ -596,6 +598,12 @@ def verify_cli_args(
     if init_tok2vec is not None and not init_tok2vec.exists():
         msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)

+    if num_workers and num_workers > 1:
+        try:
+            from spacy_ray import distributed_setup_and_train
+        except ImportError:
+            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+

 def verify_textcat_config(nlp, nlp_config):
     # if 'positive_label' is provided: double check whether it's in the data and
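Note: the same num_workers guard now also runs during argument validation, so a missing spacy_ray install is reported before any training state is built instead of partway through setup. A minimal standalone version of this fail-fast, optional-dependency check (SystemExit stands in for wasabi's msg.fail(..., exits=1)):

    def check_spacy_ray(num_workers):
        # Fail fast if distributed training was requested but spacy_ray is absent.
        if num_workers and num_workers > 1:
            try:
                import spacy_ray  # noqa: F401
            except ImportError:
                raise SystemExit("Need to `pip install spacy_ray` to use distributed training!")

    check_spacy_ray(num_workers=1)  # no-op for single-worker runs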