commit ab50385986
parent 610dfd85c2

    port
@@ -132,7 +132,7 @@ def train_cli(
     verbose: bool = Opt(False, "--verbose", "-VV", help="Display more information for debugging purposes"),
     use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
     num_workers: int = Opt(None, "-j", help="Parallel Workers"),
-    strategy: str = Opt(None, "--strategy", help="Distributed training strategy (requires spacy_ray)"),
+    strategy: str = Opt("allreduce", "--strategy", help="Distributed training strategy (requires spacy_ray)"),
     ray_address: str = Opt(None, "--address", help="Address of the Ray cluster. Multi-node training (requires spacy_ray)"),
     tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
     omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
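
The only functional change in this hunk is the default for --strategy: leaving the flag off now selects "allreduce" instead of None. A minimal sketch of how that default surfaces, assuming Opt aliases typer.Option as in spaCy's CLI helpers (the app wrapper below is illustrative, not spaCy's actual entry point):

    import typer

    app = typer.Typer()

    @app.command()
    def train_cli(
        # Same option as in the diff: the first argument to typer.Option
        # is the value used when --strategy is not passed on the command line.
        strategy: str = typer.Option("allreduce", "--strategy", help="Distributed training strategy (requires spacy_ray)"),
    ):
        print(f"strategy={strategy}")

    if __name__ == "__main__":
        app()

Run without arguments this prints strategy=allreduce; passing --strategy explicitly overrides the default as before.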
@@ -168,10 +168,7 @@ def train_cli(
     )
 
     if num_workers and num_workers > 1:
-        try:
-            from spacy_ray import distributed_setup_and_train
-        except ImportError:
-            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+        from spacy_ray import distributed_setup_and_train
         distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
     else:
         if use_gpu >= 0:
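
Here the import guard is dropped from the dispatch path: train_cli now imports distributed_setup_and_train unconditionally, and the friendly "pip install spacy_ray" failure moves into verify_cli_args (final hunk below), which runs first. A self-contained sketch of that centralized guard, assuming it is invoked before training starts:

    def require_spacy_ray():
        """Fail fast if distributed training is requested without spacy_ray."""
        try:
            from spacy_ray import distributed_setup_and_train  # noqa: F401
        except ImportError:
            # Mirrors the message in the diff; SystemExit stands in for
            # msg.fail(..., exits=1) so this sketch has no wasabi dependency.
            raise SystemExit("Need to `pip install spacy_ray` to use distributed training!")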
@@ -190,7 +187,7 @@ def train(
     weights_data: Optional[bytes] = None,
     omit_extra_lookups: bool = False,
     disable_tqdm: bool = False,
-    remote_optimizer: Optimizer = None,
+    remote_optimizer = None,
     randomization_index: int = 0
 ) -> None:
     msg.info(f"Loading config from: {config_path}")
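
Dropping the Optimizer annotation sidesteps two problems: the name no longer has to be imported where train is defined, and None is no longer passed where a bare Optimizer type is promised. Had the annotation been kept, the typed spelling would be Optional; a sketch, assuming the optimizer type comes from thinc as elsewhere in spaCy:

    from typing import Optional

    from thinc.api import Optimizer

    def train(
        # None presumably means: build the optimizer locally from the config
        # (an assumption; the diff itself does not show how it is consumed).
        remote_optimizer: Optional[Optimizer] = None,
        randomization_index: int = 0,
    ) -> None:
        ...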
@@ -321,6 +318,8 @@ def create_train_batches(nlp, corpus, cfg, randomization_index):
     while True:
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
+        # This is used when doing parallel training to
+        # ensure that the dataset is shuffled differently across all workers.
         for _ in range(randomization_index):
             random.random()
         random.shuffle(train_examples)
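
The new comments document a subtle trick: every worker starts from the same global random state, so burning randomization_index draws before shuffling desynchronizes the RNG streams and each worker sees the data in a different order. A runnable sketch of the same idea, using a local random.Random for self-containment where the diff uses the module-level random:

    import random

    def shuffled_for_worker(items, randomization_index, seed=0):
        rng = random.Random(seed)  # every worker starts from the same seed
        for _ in range(randomization_index):
            rng.random()  # advance the stream once per unit of the worker's index
        items = list(items)
        rng.shuffle(items)  # the permutation now differs per worker
        return items

    print(shuffled_for_worker(range(8), 0))  # e.g. worker 0's order
    print(shuffled_for_worker(range(8), 1))  # worker 1 gets a different order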
@@ -564,6 +563,9 @@ def verify_cli_args(
     raw_text=None,
     verbose=False,
     use_gpu=-1,
+    num_workers=None,
+    strategy=None,
+    ray_address=None,
     tag_map_path=None,
     omit_extra_lookups=False,
 ):
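
verify_cli_args grows the three distributed-training options so the caller can hand them over for validation. Because the new parameters default to None, existing call sites keep working unchanged; a toy illustration of that backward-compatible signature growth (names shortened, not spaCy's actual function):

    def verify(raw_text=None, use_gpu=-1, num_workers=None, strategy=None, ray_address=None):
        # Only the distributed settings matter for this illustration.
        return {"workers": num_workers, "strategy": strategy, "address": ray_address}

    print(verify())  # old-style call, still valid
    print(verify(num_workers=2, strategy="allreduce"))  # new options opt in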
@@ -596,6 +598,12 @@ def verify_cli_args(
     if init_tok2vec is not None and not init_tok2vec.exists():
         msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
 
+    if num_workers and num_workers > 1:
+        try:
+            from spacy_ray import distributed_setup_and_train
+        except ImportError:
+            msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
+
 
 def verify_textcat_config(nlp, nlp_config):
     # if 'positive_label' is provided: double check whether it's in the data and
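
This is the counterpart to the removal in train_cli: the spacy_ray check now runs during argument validation, before any config loading or training work, so a missing optional dependency aborts immediately with an installation hint. A sketch of what the guard amounts to, assuming msg is a wasabi Printer as in spaCy's CLI:

    from wasabi import Printer

    msg = Printer()

    def check_distributed_args(num_workers):
        if num_workers and num_workers > 1:
            try:
                from spacy_ray import distributed_setup_and_train  # noqa: F401
            except ImportError:
                # Prints a formatted error, then exits with status 1.
                msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)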