more-train-fixes

parent 26c975ec66
commit c865d833dc
@@ -21,24 +21,24 @@ class OptimizerWorker:
         if self.waiting < self.world_size - 1:
             if self.waiting == 0:
-                self.gradient[key] = gradient.copy()
+                self.grad_dict[key] = gradient.copy()
                 self.weights_dict[key] = weights.copy()
             else:
-                self.gradient[key] += gradient
+                self.grad_dict[key] += gradient
             self.waiting = self.barrier.n_waiting + 1
             self.lock.release()
             self.barrier.wait()
         else:
-            self.gradient[key] += gradient
+            self.grad_dict[key] += gradient
             self.lock.release()
-            self.gradient[key] /= self.world_size
+            self.grad_dict[key] /= self.world_size
             new_weights, new_grads = self.optimizer(
-                key, self.weights_dict[key], self.gradient[key], lr_scale=lr_scale)
+                key, self.weights_dict[key], self.grad_dict[key], lr_scale=lr_scale)
             self.weights_dict[key] = new_weights
-            self.gradient[key] = new_grads
+            self.grad_dict[key] = new_grads
             self.waiting = 0
             self.barrier.wait()
-        return self.weights_dict[key], self.gradient[key]
+        return self.weights_dict[key], self.grad_dict[key]

     def fetch(self):
         return self.optimizer
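For reference, a minimal self-contained sketch of the synchronization pattern used above: the first world_size - 1 callers fold their gradient into the shared sum and park at a barrier, and the last caller completes the sum, averages it, and releases everyone. AveragingWorker, its step() signature, and the two-thread demo are illustrative only (the real OptimizerWorker also holds the thinc optimizer and runs as a Ray actor); numpy gradients are assumed.

    import threading
    import numpy

    class AveragingWorker:
        """Illustrative stand-in for OptimizerWorker's gradient averaging."""
        def __init__(self, world_size):
            self.world_size = world_size
            self.grad_dict = {}
            self.waiting = 0
            self.lock = threading.Lock()
            self.barrier = threading.Barrier(world_size)

        def step(self, key, gradient):
            self.lock.acquire()
            if self.waiting < self.world_size - 1:
                # Not the last caller: fold the gradient in and wait.
                if self.waiting == 0:
                    self.grad_dict[key] = gradient.copy()
                else:
                    self.grad_dict[key] += gradient
                self.waiting = self.barrier.n_waiting + 1
                self.lock.release()
                self.barrier.wait()
            else:
                # Last caller: finish the sum, average, release the others.
                self.grad_dict[key] += gradient
                self.lock.release()
                self.grad_dict[key] /= self.world_size
                self.waiting = 0
                self.barrier.wait()
            return self.grad_dict[key]

    worker = AveragingWorker(world_size=2)
    results = []
    threads = [
        threading.Thread(target=lambda g=g: results.append(worker.step("W", g)))
        for g in (numpy.ones(3), 2 * numpy.ones(3))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results)  # both entries are the averaged gradient [1.5, 1.5, 1.5]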
@@ -49,12 +49,13 @@ class OptimizerWorker:
 class RayOptimizer:
     local_optimizer = None

-    def __init__(self, config_path, use_gpu, rank):
+    def __init__(self, config_path, use_gpu, world_size):
         RemoteOptimizer = ray.remote(OptimizerWorker)
+        options = {"max_concurrency": world_size}
         if use_gpu >= 0:
-            RemoteOptimizer = RemoteOptimizer.options(num_gpus=0.1)
-        self.optimizer = RemoteOptimizer.remote(config_path)
-        self.rank = rank
+            options["num_gpus"] = 0.1
+        RemoteOptimizer = RemoteOptimizer.options(**options)
+        self.optimizer = RemoteOptimizer.remote(config_path, world_size)
         self.sync()

     def sync(self):
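The constructor change above collects the actor options into a dict so that max_concurrency (needed so several training workers can sit inside the actor's barrier at once) and the fractional num_gpus reservation are applied in a single .options() call. A small sketch of that pattern with a throwaway actor; Counter and the option values are placeholders, while ray.remote, .options, and .remote are real Ray API.

    import ray

    class Counter:
        def __init__(self, start):
            self.value = start

        def increment(self):
            self.value += 1
            return self.value

    ray.init(ignore_reinit_error=True)
    RemoteCounter = ray.remote(Counter)
    options = {"max_concurrency": 2}    # allow concurrent calls into the actor
    # options["num_gpus"] = 0.1         # uncomment to reserve a fractional GPU share
    RemoteCounter = RemoteCounter.options(**options)
    counter = RemoteCounter.remote(0)
    print(ray.get(counter.increment.remote()))  # -> 1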
@@ -1,6 +1,8 @@
 """Allreduce distributed training with Ray."""

+import random
 import ray
+import numpy
 from wasabi import msg
 from .. import util

@@ -10,10 +12,11 @@ nccl = None
 from typing import Dict, Optional, Union, Tuple, List, cast
 from thinc.types import FloatsXd


 def create_optimizer(config_path):
     msg.info(f"Loading config from: {config_path}")
     config = util.load_config(config_path, create_objects=False)
-    util.fix_random_seed(config["training"]["seed"])  # Fix this.
+    util.fix_random_seed(config["training"]["seed"])
     config = util.load_config(config_path, create_objects=True)
     training = config["training"]
     return training["optimizer"]
@@ -198,7 +198,7 @@ def train_cli(

     if num_workers and num_workers > 1:
         import ray
-        ray.init(address="auto")
+        ray.init()
         if strategy == "ps":
             from spacy.cli.ray_param_server import RayOptimizer
         remote_train = ray.remote(setup_and_train)
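A bare ray.init() starts a fresh local Ray runtime inside the driver process, whereas the removed ray.init(address="auto") attaches to a cluster that must already be running (for example one started with `ray start --head`). A minimal illustration, assuming Ray is installed:

    import ray

    ray.init()      # start a local Ray instance, as the updated train_cli does
    ray.shutdown()

    # To join an existing cluster instead, one would use:
    # ray.init(address="auto")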
@@ -206,7 +206,8 @@ def train_cli(
             msg.info("Enabling GPU with Ray")
             remote_train = remote_train.options(num_gpus=0.9)

-        train_args["remote_optimizer"] = RayOptimizer(config_path, use_gpu=use_gpu)
+        train_args["remote_optimizer"] = RayOptimizer(
+            config_path, use_gpu=use_gpu, world_size=num_workers)
         ray.get([remote_train.remote(
             use_gpu,
             train_args,
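The GPU fractions are chosen so that one training task (num_gpus=0.9) and the optimizer actor (num_gpus=0.1) together fit on a single device. A hedged sketch of the driver-side dispatch follows; setup_and_train's body, the rank argument, and the train_args contents are stand-ins, since the diff cuts off the remaining call arguments.

    import ray

    def setup_and_train(use_gpu, train_args, rank):
        # Stand-in body; the real function builds the pipeline and trains it.
        return f"worker {rank}: optimizer={train_args['remote_optimizer']}"

    ray.init(ignore_reinit_error=True)
    num_workers = 2
    remote_train = ray.remote(setup_and_train)
    # remote_train = remote_train.options(num_gpus=0.9)  # only when training on GPU
    train_args = {"remote_optimizer": "<RayOptimizer handle placeholder>"}
    results = ray.get([
        remote_train.remote(-1, train_args, rank) for rank in range(num_workers)
    ])
    print(results)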
@@ -898,11 +898,17 @@ def escape_html(text):
 def use_gpu(gpu_id):
     return require_gpu(gpu_id)

+def gpu_is_available():
+    try:
+        cupy.cuda.runtime.getDeviceCount()
+        return True
+    except cupy.cuda.runtime.CUDARuntimeError:
+        return False

 def fix_random_seed(seed=0):
     random.seed(seed)
     numpy.random.seed(seed)
-    if cupy is not None:
+    if cupy is not None and gpu_is_available():
         cupy.random.seed(seed)

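gpu_is_available() lets fix_random_seed seed cupy only when a CUDA device can actually be reached, so machines that have cupy installed but no usable GPU no longer fail while seeding. Below is a self-contained variant that also guards the import itself; the try/except ImportError is an extra assumption, since the diff only adds the runtime check.

    import random
    import numpy

    try:
        import cupy
    except ImportError:  # extra guard, not part of the diff
        cupy = None

    def gpu_is_available():
        if cupy is None:
            return False
        try:
            cupy.cuda.runtime.getDeviceCount()
            return True
        except cupy.cuda.runtime.CUDARuntimeError:
            return False

    def fix_random_seed(seed=0):
        random.seed(seed)
        numpy.random.seed(seed)
        if cupy is not None and gpu_is_available():
            cupy.random.seed(seed)

    fix_random_seed(0)
    print(numpy.random.rand(2))  # identical across runs with the same seed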