prevent division by zero in most_similar method (#4488)

This commit is contained in:
Sofie Van Landeghem 2019-10-21 12:04:46 +02:00 committed by Matthew Honnibal
parent a98d1cd58e
commit d5d55312b2
2 changed files with 14 additions and 2 deletions

View File

@ -15,6 +15,7 @@ from spacy.util import decaying
import numpy import numpy
import re import re
from spacy.vectors import Vectors
from ..util import get_doc from ..util import get_doc
@ -293,6 +294,13 @@ def test_issue3410():
list(phrasematcher.pipe(docs, n_threads=4)) list(phrasematcher.pipe(docs, n_threads=4))
def test_issue3412():
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
vectors = Vectors(data=data)
keys, best_rows, scores = vectors.most_similar(numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f"))
assert(best_rows[0] == 2)
def test_issue3447(): def test_issue3447():
sizes = decaying(10.0, 1.0, 0.5) sizes = decaying(10.0, 1.0, 0.5)
size = next(sizes) size = next(sizes)

View File

@ -321,14 +321,18 @@ cdef class Vectors:
""" """
xp = get_array_module(self.data) xp = get_array_module(self.data)
vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True) norms = xp.linalg.norm(self.data, axis=1, keepdims=True)
norms[norms == 0] = 1
vectors = self.data / norms
best_rows = xp.zeros((queries.shape[0], n), dtype='i') best_rows = xp.zeros((queries.shape[0], n), dtype='i')
scores = xp.zeros((queries.shape[0], n), dtype='f') scores = xp.zeros((queries.shape[0], n), dtype='f')
# Work in batches, to avoid memory problems. # Work in batches, to avoid memory problems.
for i in range(0, queries.shape[0], batch_size): for i in range(0, queries.shape[0], batch_size):
batch = queries[i : i+batch_size] batch = queries[i : i+batch_size]
batch /= xp.linalg.norm(batch, axis=1, keepdims=True) batch_norms = xp.linalg.norm(batch, axis=1, keepdims=True)
batch_norms[batch_norms == 0] = 1
batch /= batch_norms
# batch e.g. (1024, 300) # batch e.g. (1024, 300)
# vectors e.g. (10000, 300) # vectors e.g. (10000, 300)
# sims e.g. (1024, 10000) # sims e.g. (1024, 10000)