mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-07 07:00:34 +03:00
set
This commit is contained in:
parent
8cc49c5a03
commit
d8705e1291
|
@ -84,6 +84,7 @@ class AllreduceOptimizer:
|
|||
from cupy.cuda import nccl
|
||||
self.optimizer = _create_optimizer(config_path)
|
||||
self.communicator = communicator
|
||||
self.weights_synced = set()
|
||||
|
||||
def allreduce(self, tensor):
|
||||
self.communicator.allReduce(
|
||||
|
@ -104,7 +105,11 @@ class AllreduceOptimizer:
|
|||
*,
|
||||
lr_scale: float = 1.0,
|
||||
):
|
||||
# weights = self.allreduce(weights)
|
||||
if key not in self.weights_synced:
|
||||
self.weights_synced.add(key)
|
||||
weights = self.allreduce(weights) / self.communicator.size()
|
||||
|
||||
|
||||
gradient = self.allreduce(gradient)
|
||||
flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
|
||||
return flat_weights, gradient
|
||||
|
|
Loading…
Reference in New Issue
Block a user