mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 00:04:15 +03:00
Use 0-vector for OOV lexemes (#8639)
This commit is contained in:
parent
8233359225
commit
f9fd2889b7
|
@ -3,7 +3,7 @@ from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Mode
|
||||||
from thinc.api import MultiSoftmax, list2array
|
from thinc.api import MultiSoftmax, list2array
|
||||||
from thinc.api import to_categorical, CosineDistance, L2Distance
|
from thinc.api import to_categorical, CosineDistance, L2Distance
|
||||||
|
|
||||||
from ...util import registry
|
from ...util import registry, OOV_RANK
|
||||||
from ...errors import Errors
|
from ...errors import Errors
|
||||||
from ...attrs import ID
|
from ...attrs import ID
|
||||||
|
|
||||||
|
@ -70,6 +70,7 @@ def get_vectors_loss(ops, docs, prediction, distance):
|
||||||
# and look them up all at once. This prevents data copying.
|
# and look them up all at once. This prevents data copying.
|
||||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||||
target = docs[0].vocab.vectors.data[ids]
|
target = docs[0].vocab.vectors.data[ids]
|
||||||
|
target[ids == OOV_RANK] = 0
|
||||||
d_target, loss = distance(prediction, target)
|
d_target, loss = distance(prediction, target)
|
||||||
return loss, d_target
|
return loss, d_target
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user