From 24f4908fce4740130fc5355f28e9aa87cadd9817 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 28 Jun 2022 19:50:47 +0200 Subject: [PATCH] Update vector handling in similarity methods (#11013) Distinguish between vectors that are 0 vs. missing vectors when warning about missing vectors. Update `Doc.has_vector` to match `Span.has_vector` and `Token.has_vector` for cases where the vocab has vectors but none of the tokens in the container have vectors. --- spacy/tests/vocab_vectors/test_similarity.py | 33 +++++++++++++++----- spacy/tests/vocab_vectors/test_vectors.py | 10 +++--- spacy/tokens/doc.pyx | 5 +-- spacy/tokens/span.pyx | 3 +- spacy/tokens/token.pyx | 3 +- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/spacy/tests/vocab_vectors/test_similarity.py b/spacy/tests/vocab_vectors/test_similarity.py index 47cd1f060..1efcdd81e 100644 --- a/spacy/tests/vocab_vectors/test_similarity.py +++ b/spacy/tests/vocab_vectors/test_similarity.py @@ -1,6 +1,7 @@ import pytest import numpy from spacy.tokens import Doc +from spacy.vocab import Vocab from ..util import get_cosine, add_vecs_to_vocab @@ -71,19 +72,17 @@ def test_vectors_similarity_DD(vocab, vectors): def test_vectors_similarity_TD(vocab, vectors): [(word1, vec1), (word2, vec2)] = vectors doc = Doc(vocab, words=[word1, word2]) - with pytest.warns(UserWarning): - assert isinstance(doc.similarity(doc[0]), float) - assert isinstance(doc[0].similarity(doc), float) - assert doc.similarity(doc[0]) == doc[0].similarity(doc) + assert isinstance(doc.similarity(doc[0]), float) + assert isinstance(doc[0].similarity(doc), float) + assert doc.similarity(doc[0]) == doc[0].similarity(doc) def test_vectors_similarity_TS(vocab, vectors): [(word1, vec1), (word2, vec2)] = vectors doc = Doc(vocab, words=[word1, word2]) - with pytest.warns(UserWarning): - assert isinstance(doc[:2].similarity(doc[0]), float) - assert isinstance(doc[0].similarity(doc[-2]), float) - assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2]) + assert isinstance(doc[:2].similarity(doc[0]), float) + assert isinstance(doc[0].similarity(doc[:2]), float) + assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2]) def test_vectors_similarity_DS(vocab, vectors): @@ -91,3 +90,21 @@ def test_vectors_similarity_DS(vocab, vectors): doc = Doc(vocab, words=[word1, word2]) assert isinstance(doc.similarity(doc[:2]), float) assert doc.similarity(doc[:2]) == doc[:2].similarity(doc) + + +def test_vectors_similarity_no_vectors(): + vocab = Vocab() + doc1 = Doc(vocab, words=["a", "b"]) + doc2 = Doc(vocab, words=["c", "d", "e"]) + with pytest.warns(UserWarning): + doc1.similarity(doc2) + with pytest.warns(UserWarning): + doc1.similarity(doc2[1]) + with pytest.warns(UserWarning): + doc1.similarity(doc2[:2]) + with pytest.warns(UserWarning): + doc2.similarity(doc1) + with pytest.warns(UserWarning): + doc2[1].similarity(doc1) + with pytest.warns(UserWarning): + doc2[:2].similarity(doc1) diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py index e3ad206f4..dd2cfc596 100644 --- a/spacy/tests/vocab_vectors/test_vectors.py +++ b/spacy/tests/vocab_vectors/test_vectors.py @@ -318,17 +318,15 @@ def test_vectors_lexeme_doc_similarity(vocab, text): @pytest.mark.parametrize("text", [["apple", "orange", "juice"]]) def test_vectors_span_span_similarity(vocab, text): doc = Doc(vocab, words=text) - with pytest.warns(UserWarning): - assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2]) - assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0 + assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2]) + assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0 @pytest.mark.parametrize("text", [["apple", "orange", "juice"]]) def test_vectors_span_doc_similarity(vocab, text): doc = Doc(vocab, words=text) - with pytest.warns(UserWarning): - assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2]) - assert -1.0 < doc[0:2].similarity(doc) < 1.0 + assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2]) + assert -1.0 < doc[0:2].similarity(doc) < 1.0 @pytest.mark.parametrize( diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index e38de02b4..d9a104ac8 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -607,7 +607,8 @@ cdef class Doc: if self.vocab.vectors.n_keys == 0: warnings.warn(Warnings.W007.format(obj="Doc")) if self.vector_norm == 0 or other.vector_norm == 0: - warnings.warn(Warnings.W008.format(obj="Doc")) + if not self.has_vector or not other.has_vector: + warnings.warn(Warnings.W008.format(obj="Doc")) return 0.0 vector = self.vector xp = get_array_module(vector) @@ -627,7 +628,7 @@ cdef class Doc: if "has_vector" in self.user_hooks: return self.user_hooks["has_vector"](self) elif self.vocab.vectors.size: - return True + return any(token.has_vector for token in self) elif self.tensor.size: return True else: diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index ab888ae95..c3495f497 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -354,7 +354,8 @@ cdef class Span: if self.vocab.vectors.n_keys == 0: warnings.warn(Warnings.W007.format(obj="Span")) if self.vector_norm == 0.0 or other.vector_norm == 0.0: - warnings.warn(Warnings.W008.format(obj="Span")) + if not self.has_vector or not other.has_vector: + warnings.warn(Warnings.W008.format(obj="Span")) return 0.0 vector = self.vector xp = get_array_module(vector) diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx index d14930348..7fff6b162 100644 --- a/spacy/tokens/token.pyx +++ b/spacy/tokens/token.pyx @@ -206,7 +206,8 @@ cdef class Token: if self.vocab.vectors.n_keys == 0: warnings.warn(Warnings.W007.format(obj="Token")) if self.vector_norm == 0 or other.vector_norm == 0: - warnings.warn(Warnings.W008.format(obj="Token")) + if not self.has_vector or not other.has_vector: + warnings.warn(Warnings.W008.format(obj="Token")) return 0.0 vector = self.vector xp = get_array_module(vector)