diff --git a/spacy/_ml.py b/spacy/_ml.py
index 3fafaaa09..1bc5f30cb 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -840,6 +840,8 @@ def masked_language_model(vocab, model, mask_prob=0.15):
 
         def mlm_backward(d_output, sgd=None):
             d_output *= 1 - mask
+            # Rescale gradient for number of instances.
+            d_output *= mask.size - mask.sum()
             return backprop(d_output, sgd=sgd)
 
         return output, mlm_backward