Mirror of https://github.com/explosion/spaCy.git
set

commit d8705e1291
parent 8cc49c5a03
@@ -84,6 +84,7 @@ class AllreduceOptimizer:
         from cupy.cuda import nccl
         self.optimizer = _create_optimizer(config_path)
         self.communicator = communicator
+        self.weights_synced = set()

    def allreduce(self, tensor):
        self.communicator.allReduce(
@@ -104,7 +105,11 @@ class AllreduceOptimizer:
         *,
         lr_scale: float = 1.0,
     ):
-        # weights = self.allreduce(weights)
+        if key not in self.weights_synced:
+            self.weights_synced.add(key)
+            weights = self.allreduce(weights) / self.communicator.size()
+
+
         gradient = self.allreduce(gradient)
         flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
         return flat_weights, gradient
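For readers of the diff: the new weights_synced set makes the optimizer average each parameter's weights across workers the first time that key is seen, so every worker proceeds from identical values before gradients are allreduced on every step. Below is a minimal, single-process sketch of that pattern. LocalCommunicator, AllreduceOptimizerSketch, and sgd are hypothetical stand-ins (the real class uses a cupy.cuda.nccl communicator and an optimizer built from a config); only the weights_synced bookkeeping and the call flow mirror the change above.

import numpy as np


class LocalCommunicator:
    # Single-process stand-in for the NCCL communicator; a real implementation
    # would perform an allreduce (sum) across all workers here.
    def __init__(self, world_size=1):
        self.world_size = world_size

    def size(self):
        return self.world_size

    def allreduce_sum(self, tensor):
        # With one worker there is nothing to reduce.
        return tensor


class AllreduceOptimizerSketch:
    def __init__(self, communicator, optimizer):
        self.communicator = communicator
        self.optimizer = optimizer
        self.weights_synced = set()  # parameter keys whose weights were already averaged

    def allreduce(self, tensor):
        return self.communicator.allreduce_sum(tensor)

    def __call__(self, key, weights, gradient, *, lr_scale=1.0):
        if key not in self.weights_synced:
            self.weights_synced.add(key)
            # First time this parameter is seen: average the weights so all
            # workers start from identical values.
            weights = self.allreduce(weights) / self.communicator.size()

        # Gradients are summed across workers on every step.
        gradient = self.allreduce(gradient)
        flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
        return flat_weights, gradient


def sgd(key, weights, gradient, *, lr_scale=1.0):
    # Toy inner optimizer standing in for _create_optimizer(config_path).
    return weights - 0.01 * lr_scale * gradient, gradient


opt = AllreduceOptimizerSketch(LocalCommunicator(), sgd)
weights = np.ones(4, dtype="float32")
gradient = np.full(4, 0.5, dtype="float32")
weights, gradient = opt("embed.W", weights, gradient)  # weights synced once, then updated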
|
|