mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-07 15:10:34 +03:00
with-gpu
This commit is contained in:
parent
54951aa976
commit
ef2af90f54
|
@ -31,8 +31,10 @@ class OptimizerWorker:
|
|||
class RayOptimizer:
|
||||
local_optimizer = None
|
||||
|
||||
def __init__(self, config_path):
|
||||
def __init__(self, config_path, use_gpu):
|
||||
RemoteOptimizer = ray.remote(OptimizerWorker)
|
||||
if use_gpu >= 0:
|
||||
RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
|
||||
self.optimizer = RemoteOptimizer.remote(config_path)
|
||||
self.sync()
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import math
|
|||
import srsly
|
||||
from pydantic import BaseModel, FilePath
|
||||
import plac
|
||||
import os
|
||||
import tqdm
|
||||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
|
@ -198,7 +199,11 @@ def train_cli(
|
|||
import ray
|
||||
ray.init()
|
||||
remote_train = ray.remote(setup_and_train)
|
||||
train_args["remote_optimizer"] = RayOptimizer(config_path)
|
||||
if use_gpu >= 0:
|
||||
msg.info("Enabling GPU with Ray")
|
||||
remote_train = remote_train.options(num_gpus=0.9)
|
||||
|
||||
train_args["remote_optimizer"] = RayOptimizer(config_path, use_gpu=use_gpu)
|
||||
ray.get([remote_train.remote(
|
||||
use_gpu,
|
||||
train_args,
|
||||
|
@ -210,17 +215,19 @@ def train_cli(
|
|||
world_rank = None
|
||||
world_size = None
|
||||
|
||||
def setup_and_train(use_gpu, train_args, rank, total_workers):
|
||||
if use_gpu >= 0:
|
||||
msg.info("Using GPU: {use_gpu}")
|
||||
util.use_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
if rank:
|
||||
def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
|
||||
if rank is not None:
|
||||
global world_rank
|
||||
world_rank = rank
|
||||
global world_size
|
||||
world_size = total_workers
|
||||
if use_gpu >= 0:
|
||||
use_gpu = 0
|
||||
if use_gpu >= 0:
|
||||
msg.info(f"Using GPU: {use_gpu}")
|
||||
util.use_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
train(**train_args)
|
||||
|
||||
def train(
|
||||
|
|
Loading…
Reference in New Issue
Block a user