mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 09:44:36 +03:00
Fix most_similar for vectors with unused rows (#5348)
* Fix most_similar for vectors with unused rows Address issues related to the unused rows in the vector table and `most_similar`: * Update `most_similar()` to search only through rows that are in use according to `key2row`. * Raise an error when `most_similar(n=n)` is larger than the number of vectors in the table. * Set and restore `_unset` correctly when vectors are added or deserialized so that new vectors are added in the correct row. * Set data and keys to the same length in `Vocab.prune_vectors()` to avoid spurious entries in `key2row`. * Fix regression test using `most_similar` Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
70da1fd2d6
commit
40e65d6f63
|
@ -564,6 +564,8 @@ class Errors(object):
|
||||||
E196 = ("Refusing to write to token.is_sent_end. Sentence boundaries can "
|
E196 = ("Refusing to write to token.is_sent_end. Sentence boundaries can "
|
||||||
"only be fixed with token.is_sent_start.")
|
"only be fixed with token.is_sent_start.")
|
||||||
E197 = ("Row out of bounds, unable to add row {row} for key {key}.")
|
E197 = ("Row out of bounds, unable to add row {row} for key {key}.")
|
||||||
|
E198 = ("Unable to return {n} most similar vectors for the current vectors "
|
||||||
|
"table, which contains {n_rows} vectors.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -295,7 +295,7 @@ def test_issue3410():
|
||||||
|
|
||||||
def test_issue3412():
|
def test_issue3412():
|
||||||
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
||||||
vectors = Vectors(data=data)
|
vectors = Vectors(data=data, keys=["A", "B", "C"])
|
||||||
keys, best_rows, scores = vectors.most_similar(
|
keys, best_rows, scores = vectors.most_similar(
|
||||||
numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f")
|
numpy.asarray([[9, 8, 7], [0, 0, 0]], dtype="f")
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_allclose
|
from numpy.testing import assert_allclose, assert_equal
|
||||||
from spacy._ml import cosine
|
from spacy._ml import cosine
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.vectors import Vectors
|
from spacy.vectors import Vectors
|
||||||
|
@ -11,7 +11,7 @@ from spacy.tokenizer import Tokenizer
|
||||||
from spacy.strings import hash_string
|
from spacy.strings import hash_string
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
from ..util import add_vecs_to_vocab
|
from ..util import add_vecs_to_vocab, make_tempdir
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -59,6 +59,11 @@ def most_similar_vectors_data():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def most_similar_vectors_keys():
|
||||||
|
return ["a", "b", "c", "d"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def resize_data():
|
def resize_data():
|
||||||
return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f")
|
return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f")
|
||||||
|
@ -146,11 +151,14 @@ def test_set_vector(strings, data):
|
||||||
assert list(v[strings[0]]) != list(orig[0])
|
assert list(v[strings[0]]) != list(orig[0])
|
||||||
|
|
||||||
|
|
||||||
def test_vectors_most_similar(most_similar_vectors_data):
|
def test_vectors_most_similar(most_similar_vectors_data, most_similar_vectors_keys):
|
||||||
v = Vectors(data=most_similar_vectors_data)
|
v = Vectors(data=most_similar_vectors_data, keys=most_similar_vectors_keys)
|
||||||
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
|
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
|
||||||
assert all(row[0] == i for i, row in enumerate(best_rows))
|
assert all(row[0] == i for i, row in enumerate(best_rows))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
v.most_similar(v.data, batch_size=2, n=10, sort=True)
|
||||||
|
|
||||||
|
|
||||||
def test_vectors_most_similar_identical():
|
def test_vectors_most_similar_identical():
|
||||||
"""Test that most similar identical vectors are assigned a score of 1.0."""
|
"""Test that most similar identical vectors are assigned a score of 1.0."""
|
||||||
|
@ -331,6 +339,33 @@ def test_vocab_prune_vectors():
|
||||||
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)
|
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectors_serialize():
|
||||||
|
data = numpy.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
|
||||||
|
v = Vectors(data=data, keys=["A", "B", "C"])
|
||||||
|
b = v.to_bytes()
|
||||||
|
v_r = Vectors()
|
||||||
|
v_r.from_bytes(b)
|
||||||
|
assert_equal(v.data, v_r.data)
|
||||||
|
assert v.key2row == v_r.key2row
|
||||||
|
v.resize((5, 4))
|
||||||
|
v_r.resize((5, 4))
|
||||||
|
row = v.add("D", vector=numpy.asarray([1, 2, 3, 4], dtype="f"))
|
||||||
|
row_r = v_r.add("D", vector=numpy.asarray([1, 2, 3, 4], dtype="f"))
|
||||||
|
assert row == row_r
|
||||||
|
assert_equal(v.data, v_r.data)
|
||||||
|
assert v.is_full == v_r.is_full
|
||||||
|
with make_tempdir() as d:
|
||||||
|
v.to_disk(d)
|
||||||
|
v_r.from_disk(d)
|
||||||
|
assert_equal(v.data, v_r.data)
|
||||||
|
assert v.key2row == v_r.key2row
|
||||||
|
v.resize((5, 4))
|
||||||
|
v_r.resize((5, 4))
|
||||||
|
row = v.add("D", vector=numpy.asarray([10, 20, 30, 40], dtype="f"))
|
||||||
|
row_r = v_r.add("D", vector=numpy.asarray([10, 20, 30, 40], dtype="f"))
|
||||||
|
assert row == row_r
|
||||||
|
assert_equal(v.data, v_r.data)
|
||||||
|
|
||||||
def test_vector_is_oov():
|
def test_vector_is_oov():
|
||||||
vocab = Vocab(vectors_name="test_vocab_is_oov")
|
vocab = Vocab(vectors_name="test_vocab_is_oov")
|
||||||
data = numpy.ndarray((5, 3), dtype="f")
|
data = numpy.ndarray((5, 3), dtype="f")
|
||||||
|
@ -340,4 +375,4 @@ def test_vector_is_oov():
|
||||||
vocab.set_vector("dog", data[1])
|
vocab.set_vector("dog", data[1])
|
||||||
assert vocab["cat"].is_oov is True
|
assert vocab["cat"].is_oov is True
|
||||||
assert vocab["dog"].is_oov is True
|
assert vocab["dog"].is_oov is True
|
||||||
assert vocab["hamster"].is_oov is False
|
assert vocab["hamster"].is_oov is False
|
|
@ -212,8 +212,7 @@ cdef class Vectors:
|
||||||
copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1]))
|
copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1]))
|
||||||
resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]]
|
resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]]
|
||||||
self.data = resized_array
|
self.data = resized_array
|
||||||
filled = {row for row in self.key2row.values()}
|
self._sync_unset()
|
||||||
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
|
|
||||||
removed_items = []
|
removed_items = []
|
||||||
for key, row in list(self.key2row.items()):
|
for key, row in list(self.key2row.items()):
|
||||||
if row >= shape[0]:
|
if row >= shape[0]:
|
||||||
|
@ -310,8 +309,8 @@ cdef class Vectors:
|
||||||
raise ValueError(Errors.E197.format(row=row, key=key))
|
raise ValueError(Errors.E197.format(row=row, key=key))
|
||||||
if vector is not None:
|
if vector is not None:
|
||||||
self.data[row] = vector
|
self.data[row] = vector
|
||||||
if self._unset.count(row):
|
if self._unset.count(row):
|
||||||
self._unset.erase(self._unset.find(row))
|
self._unset.erase(self._unset.find(row))
|
||||||
return row
|
return row
|
||||||
|
|
||||||
def most_similar(self, queries, *, batch_size=1024, n=1, sort=True):
|
def most_similar(self, queries, *, batch_size=1024, n=1, sort=True):
|
||||||
|
@ -330,11 +329,14 @@ cdef class Vectors:
|
||||||
RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
|
RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
|
||||||
tuple.
|
tuple.
|
||||||
"""
|
"""
|
||||||
|
filled = sorted(list({row for row in self.key2row.values()}))
|
||||||
|
if len(filled) < n:
|
||||||
|
raise ValueError(Errors.E198.format(n=n, n_rows=len(filled)))
|
||||||
xp = get_array_module(self.data)
|
xp = get_array_module(self.data)
|
||||||
|
|
||||||
norms = xp.linalg.norm(self.data, axis=1, keepdims=True)
|
norms = xp.linalg.norm(self.data[filled], axis=1, keepdims=True)
|
||||||
norms[norms == 0] = 1
|
norms[norms == 0] = 1
|
||||||
vectors = self.data / norms
|
vectors = self.data[filled] / norms
|
||||||
|
|
||||||
best_rows = xp.zeros((queries.shape[0], n), dtype='i')
|
best_rows = xp.zeros((queries.shape[0], n), dtype='i')
|
||||||
scores = xp.zeros((queries.shape[0], n), dtype='f')
|
scores = xp.zeros((queries.shape[0], n), dtype='f')
|
||||||
|
@ -356,7 +358,8 @@ cdef class Vectors:
|
||||||
scores[i:i+batch_size] = scores[sorted_index]
|
scores[i:i+batch_size] = scores[sorted_index]
|
||||||
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
||||||
|
|
||||||
xp = get_array_module(self.data)
|
for i, j in numpy.ndindex(best_rows.shape):
|
||||||
|
best_rows[i, j] = filled[best_rows[i, j]]
|
||||||
# Round values really close to 1 or -1
|
# Round values really close to 1 or -1
|
||||||
scores = xp.around(scores, decimals=4, out=scores)
|
scores = xp.around(scores, decimals=4, out=scores)
|
||||||
# Account for numerical error we want to return in range -1, 1
|
# Account for numerical error we want to return in range -1, 1
|
||||||
|
@ -419,6 +422,7 @@ cdef class Vectors:
|
||||||
("vectors", load_vectors),
|
("vectors", load_vectors),
|
||||||
))
|
))
|
||||||
util.from_disk(path, serializers, [])
|
util.from_disk(path, serializers, [])
|
||||||
|
self._sync_unset()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_bytes(self, **kwargs):
|
def to_bytes(self, **kwargs):
|
||||||
|
@ -461,4 +465,9 @@ cdef class Vectors:
|
||||||
("vectors", deserialize_weights)
|
("vectors", deserialize_weights)
|
||||||
))
|
))
|
||||||
util.from_bytes(data, deserializers, [])
|
util.from_bytes(data, deserializers, [])
|
||||||
|
self._sync_unset()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _sync_unset(self):
|
||||||
|
filled = {row for row in self.key2row.values()}
|
||||||
|
self._unset = cppset[int]({row for row in range(self.data.shape[0]) if row not in filled})
|
||||||
|
|
Loading…
Reference in New Issue
Block a user