Rescale gradients for mlm

This commit is contained in:
Matthw Honnibal 2019-10-24 17:35:37 +02:00 committed by Matthew Honnibal
parent f597992411
commit fa8ad11158

View File

@ -840,6 +840,8 @@ def masked_language_model(vocab, model, mask_prob=0.15):
def mlm_backward(d_output, sgd=None):
d_output *= 1 - mask
# Rescale gradient for number of instances.
d_output *= mask.size - mask.sum()
return backprop(d_output, sgd=sgd)
return output, mlm_backward