This commit is contained in:
Richard Liaw 2020-06-16 13:19:15 -07:00
parent 8cc49c5a03
commit d8705e1291

View File

@@ -84,6 +84,7 @@ class AllreduceOptimizer:
         from cupy.cuda import nccl
         self.optimizer = _create_optimizer(config_path)
         self.communicator = communicator
+        self.weights_synced = set()

     def allreduce(self, tensor):
         self.communicator.allReduce(
@@ -104,7 +105,11 @@ class AllreduceOptimizer:
         *,
         lr_scale: float = 1.0,
     ):
-        # weights = self.allreduce(weights)
+        if key not in self.weights_synced:
+            self.weights_synced.add(key)
+            weights = self.allreduce(weights) / self.communicator.size()
         gradient = self.allreduce(gradient)
         flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale)
         return flat_weights, gradient