Update vector handling in similarity methods (#11013)

Distinguish between vectors that are 0 vs. missing vectors when warning
about missing vectors.

Update `Doc.has_vector` to match `Span.has_vector` and
`Token.has_vector` for cases where the vocab has vectors but none of the
tokens in the container have vectors.
This commit is contained in:
Adriane Boyd 2022-06-28 19:50:47 +02:00 committed by GitHub
parent 1d5cad0b42
commit 24f4908fce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 18 deletions

View File

@ -1,6 +1,7 @@
import pytest import pytest
import numpy import numpy
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.vocab import Vocab
from ..util import get_cosine, add_vecs_to_vocab from ..util import get_cosine, add_vecs_to_vocab
@ -71,7 +72,6 @@ def test_vectors_similarity_DD(vocab, vectors):
def test_vectors_similarity_TD(vocab, vectors): def test_vectors_similarity_TD(vocab, vectors):
[(word1, vec1), (word2, vec2)] = vectors [(word1, vec1), (word2, vec2)] = vectors
doc = Doc(vocab, words=[word1, word2]) doc = Doc(vocab, words=[word1, word2])
with pytest.warns(UserWarning):
assert isinstance(doc.similarity(doc[0]), float) assert isinstance(doc.similarity(doc[0]), float)
assert isinstance(doc[0].similarity(doc), float) assert isinstance(doc[0].similarity(doc), float)
assert doc.similarity(doc[0]) == doc[0].similarity(doc) assert doc.similarity(doc[0]) == doc[0].similarity(doc)
@ -80,9 +80,8 @@ def test_vectors_similarity_TD(vocab, vectors):
def test_vectors_similarity_TS(vocab, vectors): def test_vectors_similarity_TS(vocab, vectors):
[(word1, vec1), (word2, vec2)] = vectors [(word1, vec1), (word2, vec2)] = vectors
doc = Doc(vocab, words=[word1, word2]) doc = Doc(vocab, words=[word1, word2])
with pytest.warns(UserWarning):
assert isinstance(doc[:2].similarity(doc[0]), float) assert isinstance(doc[:2].similarity(doc[0]), float)
assert isinstance(doc[0].similarity(doc[-2]), float) assert isinstance(doc[0].similarity(doc[:2]), float)
assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2]) assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
@ -91,3 +90,21 @@ def test_vectors_similarity_DS(vocab, vectors):
doc = Doc(vocab, words=[word1, word2]) doc = Doc(vocab, words=[word1, word2])
assert isinstance(doc.similarity(doc[:2]), float) assert isinstance(doc.similarity(doc[:2]), float)
assert doc.similarity(doc[:2]) == doc[:2].similarity(doc) assert doc.similarity(doc[:2]) == doc[:2].similarity(doc)
def test_vectors_similarity_no_vectors():
vocab = Vocab()
doc1 = Doc(vocab, words=["a", "b"])
doc2 = Doc(vocab, words=["c", "d", "e"])
with pytest.warns(UserWarning):
doc1.similarity(doc2)
with pytest.warns(UserWarning):
doc1.similarity(doc2[1])
with pytest.warns(UserWarning):
doc1.similarity(doc2[:2])
with pytest.warns(UserWarning):
doc2.similarity(doc1)
with pytest.warns(UserWarning):
doc2[1].similarity(doc1)
with pytest.warns(UserWarning):
doc2[:2].similarity(doc1)

View File

@ -318,7 +318,6 @@ def test_vectors_lexeme_doc_similarity(vocab, text):
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]]) @pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
def test_vectors_span_span_similarity(vocab, text): def test_vectors_span_span_similarity(vocab, text):
doc = Doc(vocab, words=text) doc = Doc(vocab, words=text)
with pytest.warns(UserWarning):
assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2]) assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0 assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0
@ -326,7 +325,6 @@ def test_vectors_span_span_similarity(vocab, text):
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]]) @pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
def test_vectors_span_doc_similarity(vocab, text): def test_vectors_span_doc_similarity(vocab, text):
doc = Doc(vocab, words=text) doc = Doc(vocab, words=text)
with pytest.warns(UserWarning):
assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2]) assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
assert -1.0 < doc[0:2].similarity(doc) < 1.0 assert -1.0 < doc[0:2].similarity(doc) < 1.0

View File

@ -607,6 +607,7 @@ cdef class Doc:
if self.vocab.vectors.n_keys == 0: if self.vocab.vectors.n_keys == 0:
warnings.warn(Warnings.W007.format(obj="Doc")) warnings.warn(Warnings.W007.format(obj="Doc"))
if self.vector_norm == 0 or other.vector_norm == 0: if self.vector_norm == 0 or other.vector_norm == 0:
if not self.has_vector or not other.has_vector:
warnings.warn(Warnings.W008.format(obj="Doc")) warnings.warn(Warnings.W008.format(obj="Doc"))
return 0.0 return 0.0
vector = self.vector vector = self.vector
@ -627,7 +628,7 @@ cdef class Doc:
if "has_vector" in self.user_hooks: if "has_vector" in self.user_hooks:
return self.user_hooks["has_vector"](self) return self.user_hooks["has_vector"](self)
elif self.vocab.vectors.size: elif self.vocab.vectors.size:
return True return any(token.has_vector for token in self)
elif self.tensor.size: elif self.tensor.size:
return True return True
else: else:

View File

@ -354,6 +354,7 @@ cdef class Span:
if self.vocab.vectors.n_keys == 0: if self.vocab.vectors.n_keys == 0:
warnings.warn(Warnings.W007.format(obj="Span")) warnings.warn(Warnings.W007.format(obj="Span"))
if self.vector_norm == 0.0 or other.vector_norm == 0.0: if self.vector_norm == 0.0 or other.vector_norm == 0.0:
if not self.has_vector or not other.has_vector:
warnings.warn(Warnings.W008.format(obj="Span")) warnings.warn(Warnings.W008.format(obj="Span"))
return 0.0 return 0.0
vector = self.vector vector = self.vector

View File

@ -206,6 +206,7 @@ cdef class Token:
if self.vocab.vectors.n_keys == 0: if self.vocab.vectors.n_keys == 0:
warnings.warn(Warnings.W007.format(obj="Token")) warnings.warn(Warnings.W007.format(obj="Token"))
if self.vector_norm == 0 or other.vector_norm == 0: if self.vector_norm == 0 or other.vector_norm == 0:
if not self.has_vector or not other.has_vector:
warnings.warn(Warnings.W008.format(obj="Token")) warnings.warn(Warnings.W008.format(obj="Token"))
return 0.0 return 0.0
vector = self.vector vector = self.vector