diff --git a/spacy/cli/ray_utils.py b/spacy/cli/ray_utils.py
index 242721223..1a7c1f0e6 100644
--- a/spacy/cli/ray_utils.py
+++ b/spacy/cli/ray_utils.py
@@ -31,8 +31,10 @@ class OptimizerWorker:
 class RayOptimizer:
     local_optimizer = None
 
-    def __init__(self, config_path):
+    def __init__(self, config_path, use_gpu):
         RemoteOptimizer = ray.remote(OptimizerWorker)
+        if use_gpu >= 0:
+            RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
         self.optimizer = RemoteOptimizer.remote(config_path)
         self.sync()
 
diff --git a/spacy/cli/train_from_config.py b/spacy/cli/train_from_config.py
index 4bc7dc33e..7a4bfece7 100644
--- a/spacy/cli/train_from_config.py
+++ b/spacy/cli/train_from_config.py
@@ -4,6 +4,7 @@ import math
 import srsly
 from pydantic import BaseModel, FilePath
 import plac
+import os
 import tqdm
 from pathlib import Path
 from wasabi import msg
@@ -198,7 +199,11 @@ def train_cli(
         import ray
         ray.init()
         remote_train = ray.remote(setup_and_train)
-        train_args["remote_optimizer"] = RayOptimizer(config_path)
+        if use_gpu >= 0:
+            msg.info("Enabling GPU with Ray")
+            remote_train = remote_train.options(num_gpus=0.9)
+
+        train_args["remote_optimizer"] = RayOptimizer(config_path, use_gpu=use_gpu)
         ray.get([remote_train.remote(
             use_gpu,
             train_args,
@@ -210,17 +215,19 @@
 world_rank = None
 world_size = None
 
-def setup_and_train(use_gpu, train_args, rank, total_workers):
-    if use_gpu >= 0:
-        msg.info("Using GPU: {use_gpu}")
-        util.use_gpu(use_gpu)
-    else:
-        msg.info("Using CPU")
-    if rank:
+def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
+    if rank is not None:
         global world_rank
         world_rank = rank
         global world_size
         world_size = total_workers
+    if use_gpu >= 0:
+        use_gpu = 0
+    if use_gpu >= 0:
+        msg.info(f"Using GPU: {use_gpu}")
+        util.use_gpu(use_gpu)
+    else:
+        msg.info("Using CPU")
     train(**train_args)
 
 def train(
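
Note on the GPU arithmetic above (not part of the patch): the trainer task
reserves num_gpus=0.9 and the optimizer worker num_gpus=0.1, so the pair fits
on one physical GPU. Ray exposes each worker's allocation through
CUDA_VISIBLE_DEVICES, which is presumably why setup_and_train remaps use_gpu
to 0 inside the worker: whatever physical device is assigned, it is always
addressed as device 0 from the worker's point of view. A minimal standalone
sketch of this fractional-GPU scheme follows; report_devices is a hypothetical
helper for illustration only, not part of spaCy or this patch.

    import os
    import ray

    # Declare one logical GPU; this runs even on a machine without real GPUs.
    ray.init(num_gpus=1)

    @ray.remote
    def report_devices():
        # Ray sets CUDA_VISIBLE_DEVICES per worker process, so a task sees
        # only its allocation, always starting at device index 0.
        return os.environ.get("CUDA_VISIBLE_DEVICES")

    trainer = report_devices.options(num_gpus=0.9)    # mirrors remote_train
    optimizer = report_devices.options(num_gpus=0.1)  # mirrors RemoteOptimizer
    print(ray.get([trainer.remote(), optimizer.remote()]))  # -> ['0', '0']

Because 0.9 + 0.1 <= 1, Ray can schedule both tasks on the same device
concurrently; the fractions are a scheduling hint, not a hard memory limit.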