From d5d55312b2c5184ec96f8f073eab1b045b441fe4 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Mon, 21 Oct 2019 12:04:46 +0200 Subject: [PATCH] prevent division by zero in most_similar method (#4488) --- spacy/tests/regression/test_issue3001-3500.py | 8 ++++++++ spacy/vectors.pyx | 8 ++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/spacy/tests/regression/test_issue3001-3500.py b/spacy/tests/regression/test_issue3001-3500.py index def95ac73..8ed243051 100644 --- a/spacy/tests/regression/test_issue3001-3500.py +++ b/spacy/tests/regression/test_issue3001-3500.py @@ -15,6 +15,7 @@ from spacy.util import decaying import numpy import re +from spacy.vectors import Vectors from ..util import get_doc @@ -293,6 +294,13 @@ def test_issue3410(): list(phrasematcher.pipe(docs, n_threads=4)) +def test_issue3412(): + data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f") + vectors = Vectors(data=data) + keys, best_rows, scores = vectors.most_similar(numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f")) + assert(best_rows[0] == 2) + + def test_issue3447(): sizes = decaying(10.0, 1.0, 0.5) size = next(sizes) diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 6ad1202de..0f015521a 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -321,14 +321,18 @@ cdef class Vectors: """ xp = get_array_module(self.data) - vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True) + norms = xp.linalg.norm(self.data, axis=1, keepdims=True) + norms[norms == 0] = 1 + vectors = self.data / norms best_rows = xp.zeros((queries.shape[0], n), dtype='i') scores = xp.zeros((queries.shape[0], n), dtype='f') # Work in batches, to avoid memory problems. for i in range(0, queries.shape[0], batch_size): batch = queries[i : i+batch_size] - batch /= xp.linalg.norm(batch, axis=1, keepdims=True) + batch_norms = xp.linalg.norm(batch, axis=1, keepdims=True) + batch_norms[batch_norms == 0] = 1 + batch /= batch_norms # batch e.g. (1024, 300) # vectors e.g. (10000, 300) # sims e.g. (1024, 10000)