mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Most similar bug (#4446)
* Add batch size indexing * Don't sort if n == 1 * Add test for most similar vectors issue * Change > to >=
This commit is contained in:
parent
4a77d03ff7
commit
e646956176
|
@ -50,6 +50,13 @@ def ngrams_vocab(en_vocab, ngrams_vectors):
|
|||
def data():
|
||||
return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
|
||||
|
||||
@pytest.fixture
|
||||
def most_similar_vectors_data():
|
||||
return numpy.asarray([[0.0, 1.0, 2.0],
|
||||
[1.0, -2.0, 4.0],
|
||||
[1.0, 1.0, -1.0],
|
||||
[2.0, 3.0, 1.0]], dtype="f")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resize_data():
|
||||
|
@ -127,6 +134,12 @@ def test_set_vector(strings, data):
|
|||
assert list(v[strings[0]]) != list(orig[0])
|
||||
|
||||
|
||||
def test_vectors_most_similar(most_similar_vectors_data):
|
||||
v = Vectors(data=most_similar_vectors_data)
|
||||
_, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True)
|
||||
assert all(row[0] == i for i, row in enumerate(best_rows))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text", ["apple and orange"])
|
||||
def test_vectors_token_vector(tokenizer_v, vectors, text):
|
||||
doc = tokenizer_v(text)
|
||||
|
@ -284,7 +297,7 @@ def test_vocab_prune_vectors():
|
|||
vocab.set_vector("dog", data[1])
|
||||
vocab.set_vector("kitten", data[2])
|
||||
|
||||
remap = vocab.prune_vectors(2)
|
||||
remap = vocab.prune_vectors(2, batch_size=2)
|
||||
assert list(remap.keys()) == ["kitten"]
|
||||
neighbour, similarity = list(remap.values())[0]
|
||||
assert neighbour == "cat", remap
|
||||
|
|
|
@ -336,8 +336,8 @@ cdef class Vectors:
|
|||
best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:]
|
||||
scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:]
|
||||
|
||||
if sort:
|
||||
sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
||||
if sort and n >= 2:
|
||||
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
||||
scores[i:i+batch_size] = scores[sorted_index]
|
||||
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user