Use 0-vector for OOV lexemes (#8639)

This commit is contained in:
Adriane Boyd 2021-07-13 06:48:12 +02:00 committed by GitHub
parent 8233359225
commit f9fd2889b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,7 +3,7 @@ from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Mode
from thinc.api import MultiSoftmax, list2array from thinc.api import MultiSoftmax, list2array
from thinc.api import to_categorical, CosineDistance, L2Distance from thinc.api import to_categorical, CosineDistance, L2Distance
from ...util import registry from ...util import registry, OOV_RANK
from ...errors import Errors from ...errors import Errors
from ...attrs import ID from ...attrs import ID
@ -70,6 +70,7 @@ def get_vectors_loss(ops, docs, prediction, distance):
# and look them up all at once. This prevents data copying. # and look them up all at once. This prevents data copying.
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs]) ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
target = docs[0].vocab.vectors.data[ids] target = docs[0].vocab.vectors.data[ids]
target[ids == OOV_RANK] = 0
d_target, loss = distance(prediction, target) d_target, loss = distance(prediction, target)
return loss, d_target return loss, d_target