mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-01 10:23:07 +03:00
address
This commit is contained in:
parent
bd679cd8c7
commit
ad6448d602
|
@ -128,6 +128,9 @@ class ConfigSchema(BaseModel):
|
||||||
use_gpu=("Use GPU", "option", "g", int),
|
use_gpu=("Use GPU", "option", "g", int),
|
||||||
num_workers=("Parallel Workers", "option", "j", int),
|
num_workers=("Parallel Workers", "option", "j", int),
|
||||||
strategy=("Distributed training strategy (requires spacy_ray)", "option", "strat", str),
|
strategy=("Distributed training strategy (requires spacy_ray)", "option", "strat", str),
|
||||||
|
ray_address=(
|
||||||
|
"Address of Ray cluster. Only required for multi-node training (requires spacy_ray)",
|
||||||
|
"option", "address", str),
|
||||||
tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
|
tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
|
||||||
omit_extra_lookups=("Don't include extra lookups in model", "flag", "OEL", bool),
|
omit_extra_lookups=("Don't include extra lookups in model", "flag", "OEL", bool),
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -143,6 +146,7 @@ def train_cli(
|
||||||
use_gpu=-1,
|
use_gpu=-1,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
strategy="allreduce",
|
strategy="allreduce",
|
||||||
|
ray_address=None,
|
||||||
tag_map_path=None,
|
tag_map_path=None,
|
||||||
omit_extra_lookups=False,
|
omit_extra_lookups=False,
|
||||||
):
|
):
|
||||||
|
@ -200,7 +204,7 @@ def train_cli(
|
||||||
from spacy_ray import distributed_setup_and_train
|
from spacy_ray import distributed_setup_and_train
|
||||||
except ImportError:
|
except ImportError:
|
||||||
msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
|
msg.fail("Need to `pip install spacy_ray` to use distributed training!", exits=1)
|
||||||
distributed_setup_and_train(use_gpu, num_workers, strategy, train_args)
|
distributed_setup_and_train(use_gpu, num_workers, strategy, ray_address, train_args)
|
||||||
else:
|
else:
|
||||||
if use_gpu >= 0:
|
if use_gpu >= 0:
|
||||||
msg.info(f"Using GPU: {use_gpu}")
|
msg.info(f"Using GPU: {use_gpu}")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user