mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
prevent division by zero in most_similar method (#4488)
This commit is contained in:
parent
a98d1cd58e
commit
d5d55312b2
|
@ -15,6 +15,7 @@ from spacy.util import decaying
|
||||||
import numpy
|
import numpy
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from spacy.vectors import Vectors
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
|
||||||
|
|
||||||
|
@ -293,6 +294,13 @@ def test_issue3410():
|
||||||
list(phrasematcher.pipe(docs, n_threads=4))
|
list(phrasematcher.pipe(docs, n_threads=4))
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue3412():
|
||||||
|
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
||||||
|
vectors = Vectors(data=data)
|
||||||
|
keys, best_rows, scores = vectors.most_similar(numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f"))
|
||||||
|
assert(best_rows[0] == 2)
|
||||||
|
|
||||||
|
|
||||||
def test_issue3447():
|
def test_issue3447():
|
||||||
sizes = decaying(10.0, 1.0, 0.5)
|
sizes = decaying(10.0, 1.0, 0.5)
|
||||||
size = next(sizes)
|
size = next(sizes)
|
||||||
|
|
|
@ -321,14 +321,18 @@ cdef class Vectors:
|
||||||
"""
|
"""
|
||||||
xp = get_array_module(self.data)
|
xp = get_array_module(self.data)
|
||||||
|
|
||||||
vectors = self.data / xp.linalg.norm(self.data, axis=1, keepdims=True)
|
norms = xp.linalg.norm(self.data, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
vectors = self.data / norms
|
||||||
|
|
||||||
best_rows = xp.zeros((queries.shape[0], n), dtype='i')
|
best_rows = xp.zeros((queries.shape[0], n), dtype='i')
|
||||||
scores = xp.zeros((queries.shape[0], n), dtype='f')
|
scores = xp.zeros((queries.shape[0], n), dtype='f')
|
||||||
# Work in batches, to avoid memory problems.
|
# Work in batches, to avoid memory problems.
|
||||||
for i in range(0, queries.shape[0], batch_size):
|
for i in range(0, queries.shape[0], batch_size):
|
||||||
batch = queries[i : i+batch_size]
|
batch = queries[i : i+batch_size]
|
||||||
batch /= xp.linalg.norm(batch, axis=1, keepdims=True)
|
batch_norms = xp.linalg.norm(batch, axis=1, keepdims=True)
|
||||||
|
batch_norms[batch_norms == 0] = 1
|
||||||
|
batch /= batch_norms
|
||||||
# batch e.g. (1024, 300)
|
# batch e.g. (1024, 300)
|
||||||
# vectors e.g. (10000, 300)
|
# vectors e.g. (10000, 300)
|
||||||
# sims e.g. (1024, 10000)
|
# sims e.g. (1024, 10000)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user