mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-28 10:14:07 +03:00
Fix train CLI
This commit is contained in:
parent
4c5d6b13c8
commit
5f728dd24d
|
@ -156,7 +156,7 @@ def train_cli(
|
||||||
with init_tok2vec.open("rb") as file_:
|
with init_tok2vec.open("rb") as file_:
|
||||||
weights_data = file_.read()
|
weights_data = file_.read()
|
||||||
|
|
||||||
if num_workers and num_workers > 1:
|
if num_workers and num_workers >= 1:
|
||||||
from _ray_async_utils import distributed_setup_and_train
|
from _ray_async_utils import distributed_setup_and_train
|
||||||
distributed_setup_and_train(
|
distributed_setup_and_train(
|
||||||
use_gpu,
|
use_gpu,
|
||||||
|
@ -172,6 +172,8 @@ def train_cli(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg.info(f"Loading config from: {config_path}")
|
msg.info(f"Loading config from: {config_path}")
|
||||||
|
if use_gpu >= 0:
|
||||||
|
require_gpu(use_gpu)
|
||||||
nlp, config = load_nlp_and_config(config_path)
|
nlp, config = load_nlp_and_config(config_path)
|
||||||
corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
|
corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user