diff --git a/spacy/cli/train.py b/spacy/cli/train.py index f6e585edc..fde88d97c 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -156,7 +156,7 @@ def train_cli( with init_tok2vec.open("rb") as file_: weights_data = file_.read() - if num_workers and num_workers > 1: + if num_workers and num_workers >= 1: from _ray_async_utils import distributed_setup_and_train distributed_setup_and_train( use_gpu, @@ -172,6 +172,8 @@ def train_cli( ) else: msg.info(f"Loading config from: {config_path}") + if use_gpu >= 0: + require_gpu(use_gpu) nlp, config = load_nlp_and_config(config_path) corpus = Corpus(train_path, dev_path, limit=config["training"]["limit"])