diff --git a/spacy/cli/ray_utils.py b/spacy/cli/ray_utils.py index 8bfd243a1..fd5efffee 100644 --- a/spacy/cli/ray_utils.py +++ b/spacy/cli/ray_utils.py @@ -84,6 +84,7 @@ class AllreduceOptimizer: from cupy.cuda import nccl self.optimizer = _create_optimizer(config_path) self.communicator = communicator + self.weights_synced = set() def allreduce(self, tensor): self.communicator.allReduce( @@ -104,7 +105,11 @@ class AllreduceOptimizer: *, lr_scale: float = 1.0, ): - # weights = self.allreduce(weights) + if key not in self.weights_synced: + self.weights_synced.add(key) + weights = self.allreduce(weights) / self.communicator.size() + + gradient = self.allreduce(gradient) flat_weights, gradient = self.optimizer(key, weights, gradient, lr_scale=lr_scale) return flat_weights, gradient