mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Clip most_similar to range [-1, 1] * Add/fix vectors tests * Fix test
This commit is contained in:
parent
74a19aeb1c
commit
9489c5f6b2
|
@ -141,7 +141,6 @@ def test_vectors_most_similar(most_similar_vectors_data):
|
||||||
assert all(row[0] == i for i, row in enumerate(best_rows))
|
assert all(row[0] == i for i, row in enumerate(best_rows))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_vectors_most_similar_identical():
|
def test_vectors_most_similar_identical():
|
||||||
"""Test that most similar identical vectors are assigned a score of 1.0."""
|
"""Test that most similar identical vectors are assigned a score of 1.0."""
|
||||||
data = numpy.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
|
data = numpy.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
|
||||||
|
@ -315,4 +314,4 @@ def test_vocab_prune_vectors():
|
||||||
assert list(remap.keys()) == ["kitten"]
|
assert list(remap.keys()) == ["kitten"]
|
||||||
neighbour, similarity = list(remap.values())[0]
|
neighbour, similarity = list(remap.values())[0]
|
||||||
assert neighbour == "cat", remap
|
assert neighbour == "cat", remap
|
||||||
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-6)
|
assert_allclose(similarity, cosine(data[0], data[2]), atol=1e-4, rtol=1e-3)
|
||||||
|
|
|
@ -344,8 +344,12 @@ cdef class Vectors:
|
||||||
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1]
|
||||||
scores[i:i+batch_size] = scores[sorted_index]
|
scores[i:i+batch_size] = scores[sorted_index]
|
||||||
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
best_rows[i:i+batch_size] = best_rows[sorted_index]
|
||||||
|
|
||||||
xp = get_array_module(self.data)
|
xp = get_array_module(self.data)
|
||||||
|
# Round values really close to 1 or -1
|
||||||
|
scores = xp.around(scores, decimals=4, out=scores)
|
||||||
|
# Account for numerical error we want to return in range -1, 1
|
||||||
|
scores = xp.clip(scores, a_min=-1, a_max=1, out=scores)
|
||||||
row2key = {row: key for key, row in self.key2row.items()}
|
row2key = {row: key for key, row in self.key2row.items()}
|
||||||
keys = xp.asarray(
|
keys = xp.asarray(
|
||||||
[[row2key[row] for row in best_rows[i] if row in row2key]
|
[[row2key[row] for row in best_rows[i] if row in row2key]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user