Rescale gradients for mlm

This commit is contained in:
Matthw Honnibal 2019-10-24 17:35:37 +02:00
parent 7d81d17ce5
commit 73b1f651d4

View File

@ -966,6 +966,8 @@ def masked_language_model(vocab, model, mask_prob=0.15):
def mlm_backward(d_output, sgd=None):
d_output *= 1 - mask
# Rescale gradient for number of instances.
d_output *= mask.size - mask.sum()
return backprop(d_output, sgd=sgd)
return output, mlm_backward