mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Most similar bug (#4446)
* Add batch size indexing * Don't sort if n == 1 * Add test for most similar vectors issue * Change > to >=
This commit is contained in:
parent
4a77d03ff7
commit
e646956176
|
@ -50,6 +50,13 @@ def ngrams_vocab(en_vocab, ngrams_vectors):
|
||||||
def data():
|
def data():
|
||||||
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
|
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def most_similar_vectors_data():
|
||||||
|
return numpy.asarray([[0.0, 1.0, 2.0],
|
||||||
|
[1.0, -2.0, 4.0],
|
||||||
|
[1.0, 1.0, -1.0],
|
||||||
|
[2.0, 3.0, 1.0]], dtype="f")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def resize_data():
|
def resize_data():
|
||||||
|
@ -127,6 +134,12 @@ def test_set_vector(strings, data):
|
||||||
assert list(v[strings[0]]) != list(orig[0])
|
assert list(v[strings[0]]) != list(orig[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectors_most_similar(most_similar_vectors_data):
|
||||||
|
v = Vectors(data=most_similar_vectors_data)
|
||||||
|
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
|
||||||
|
assert all(row[0] == i for i, row in enumerate(best_rows))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text", ["apple and orange"])
|
@pytest.mark.parametrize("text", ["apple and orange"])
|
||||||
def test_vectors_token_vector(tokenizer_v, vectors, text):
|
def test_vectors_token_vector(tokenizer_v, vectors, text):
|
||||||
doc = tokenizer_v(text)
|
doc = tokenizer_v(text)
|
||||||
|
@ -284,7 +297,7 @@ def test_vocab_prune_vectors():
|
||||||
vocab.set_vector("dog", data[1])
|
vocab.set_vector("dog", data[1])
|
||||||
vocab.set_vector("kitten", data[2])
|
vocab.set_vector("kitten", data[2])
|
||||||
|
|
||||||
remap = vocab.prune_vectors(2)
|
remap = vocab.prune_vectors(2, batch_size=2)
|
||||||
assert list(remap.keys()) == ["kitten"]
|
assert list(remap.keys()) == ["kitten"]
|
||||||
neighbour, similarity = list(remap.values())[0]
|
neighbour, similarity = list(remap.values())[0]
|
||||||
assert neighbour == "cat", remap
|
assert neighbour == "cat", remap
|
||||||
|
|
|
@ -336,8 +336,8 @@ cdef class Vectors:
|
||||||
best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
|
best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
|
||||||
scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
|
scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
|
||||||
|
|
||||||
if sort:
|
if sort and n >= 2:
|
||||||
sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
||||||
scores[i:i+batch_size] = scores[sorted_index]
|
scores[i:i+batch_size] = scores[sorted_index]
|
||||||
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user