mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
prevent division by zero in most_similar method (#4488)
This commit is contained in:
parent
a98d1cd58e
commit
d5d55312b2
|
@ -15,6 +15,7 @@ from spacy.util import decaying
|
|||
import numpy
|
||||
import re
|
||||
|
||||
from spacy.vectors import Vectors
|
||||
from ..util import get_doc
|
||||
|
||||
|
||||
|
@ -293,6 +294,13 @@ def test_issue3410():
|
|||
list(phrasematcher.pipe(docs, n_threads=4))
|
||||
|
||||
|
||||
def test_issue3412():
|
||||
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
||||
vectors = Vectors(data=data)
|
||||
keys, best_rows, scores = vectors.most_similar(numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f"))
|
||||
assert(best_rows[0] == 2)
|
||||
|
||||
|
||||
def test_issue3447():
|
||||
sizes = decaying(10.0, 1.0, 0.5)
|
||||
size = next(sizes)
|
||||
|
|
|
@ -321,14 +321,18 @@ cdef class Vectors:
|
|||
"""
|
||||
xp = get_array_module(self.data)
|
||||
|
||||
vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True)
|
||||
norms = xp.linalg.norm(self.data, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
vectors = self.data / norms
|
||||
|
||||
best_rows = xp.zeros((queries.shape[0], n), dtype='i')
|
||||
scores = xp.zeros((queries.shape[0], n), dtype='f')
|
||||
# Work in batches, to avoid memory problems.
|
||||
for i in range(0, queries.shape[0], batch_size):
|
||||
batch = queries[i : i+batch_size]
|
||||
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)
|
||||
batch_norms = xp.linalg.norm(batch, axis=1, keepdims=True)
|
||||
batch_norms[batch_norms == 0] = 1
|
||||
batch /= batch_norms
|
||||
# batch e.g. (1024, 300)
|
||||
# vectors e.g. (10000, 300)
|
||||
# sims e.g. (1024, 10000)
|
||||
|
|
Loading…
Reference in New Issue
Block a user