mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-07 15:10:34 +03:00
with-gpu
This commit is contained in:
parent
54951aa976
commit
ef2af90f54
|
@ -31,8 +31,10 @@ class OptimizerWorker:
|
||||||
class RayOptimizer:
|
class RayOptimizer:
|
||||||
local_optimizer = None
|
local_optimizer = None
|
||||||
|
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path, use_gpu):
|
||||||
RemoteOptimizer = ray.remote(OptimizerWorker)
|
RemoteOptimizer = ray.remote(OptimizerWorker)
|
||||||
|
if use_gpu >= 0:
|
||||||
|
RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
|
||||||
self.optimizer = RemoteOptimizer.remote(config_path)
|
self.optimizer = RemoteOptimizer.remote(config_path)
|
||||||
self.sync()
|
self.sync()
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import math
|
||||||
import srsly
|
import srsly
|
||||||
from pydantic import BaseModel, FilePath
|
from pydantic import BaseModel, FilePath
|
||||||
import plac
|
import plac
|
||||||
|
import os
|
||||||
import tqdm
|
import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
|
@ -198,7 +199,11 @@ def train_cli(
|
||||||
import ray
|
import ray
|
||||||
ray.init()
|
ray.init()
|
||||||
remote_train = ray.remote(setup_and_train)
|
remote_train = ray.remote(setup_and_train)
|
||||||
train_args["remote_optimizer"] = RayOptimizer(config_path)
|
if use_gpu >= 0:
|
||||||
|
msg.info("Enabling GPU with Ray")
|
||||||
|
remote_train = remote_train.options(num_gpus=0.9)
|
||||||
|
|
||||||
|
train_args["remote_optimizer"] = RayOptimizer(config_path, use_gpu=use_gpu)
|
||||||
ray.get([remote_train.remote(
|
ray.get([remote_train.remote(
|
||||||
use_gpu,
|
use_gpu,
|
||||||
train_args,
|
train_args,
|
||||||
|
@ -210,17 +215,19 @@ def train_cli(
|
||||||
world_rank = None
|
world_rank = None
|
||||||
world_size = None
|
world_size = None
|
||||||
|
|
||||||
def setup_and_train(use_gpu, train_args, rank, total_workers):
|
def setup_and_train(use_gpu, train_args, rank=None, total_workers=None):
|
||||||
if use_gpu >= 0:
|
if rank is not None:
|
||||||
msg.info("Using GPU: {use_gpu}")
|
|
||||||
util.use_gpu(use_gpu)
|
|
||||||
else:
|
|
||||||
msg.info("Using CPU")
|
|
||||||
if rank:
|
|
||||||
global world_rank
|
global world_rank
|
||||||
world_rank = rank
|
world_rank = rank
|
||||||
global world_size
|
global world_size
|
||||||
world_size = total_workers
|
world_size = total_workers
|
||||||
|
if use_gpu >= 0:
|
||||||
|
use_gpu = 0
|
||||||
|
if use_gpu >= 0:
|
||||||
|
msg.info(f"Using GPU: {use_gpu}")
|
||||||
|
util.use_gpu(use_gpu)
|
||||||
|
else:
|
||||||
|
msg.info("Using CPU")
|
||||||
train(**train_args)
|
train(**train_args)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user