mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 13:47:13 +03:00
Rescale gradients for mlm
This commit is contained in:
parent
7d81d17ce5
commit
73b1f651d4
|
@ -966,6 +966,8 @@ def masked_language_model(vocab, model, mask_prob=0.15):
|
|||
|
||||
def mlm_backward(d_output, sgd=None):
|
||||
d_output *= 1 - mask
|
||||
# Rescale gradient for number of instances.
|
||||
d_output *= mask.size - mask.sum()
|
||||
return backprop(d_output, sgd=sgd)
|
||||
|
||||
return output, mlm_backward
|
||||
|
|
Loading…
Reference in New Issue
Block a user