mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Update vector handling in similarity methods (#11013)
Distinguish between vectors that are 0 vs. missing vectors when warning about missing vectors. Update `Doc.has_vector` to match `Span.has_vector` and `Token.has_vector` for cases where the vocab has vectors but none of the tokens in the container have vectors.
This commit is contained in:
parent
1d5cad0b42
commit
24f4908fce
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from spacy.tokens import Doc
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
from ..util import get_cosine, add_vecs_to_vocab
|
||||
|
||||
|
@ -71,19 +72,17 @@ def test_vectors_similarity_DD(vocab, vectors):
|
|||
def test_vectors_similarity_TD(vocab, vectors):
|
||||
[(word1, vec1), (word2, vec2)] = vectors
|
||||
doc = Doc(vocab, words=[word1, word2])
|
||||
with pytest.warns(UserWarning):
|
||||
assert isinstance(doc.similarity(doc[0]), float)
|
||||
assert isinstance(doc[0].similarity(doc), float)
|
||||
assert doc.similarity(doc[0]) == doc[0].similarity(doc)
|
||||
assert isinstance(doc.similarity(doc[0]), float)
|
||||
assert isinstance(doc[0].similarity(doc), float)
|
||||
assert doc.similarity(doc[0]) == doc[0].similarity(doc)
|
||||
|
||||
|
||||
def test_vectors_similarity_TS(vocab, vectors):
|
||||
[(word1, vec1), (word2, vec2)] = vectors
|
||||
doc = Doc(vocab, words=[word1, word2])
|
||||
with pytest.warns(UserWarning):
|
||||
assert isinstance(doc[:2].similarity(doc[0]), float)
|
||||
assert isinstance(doc[0].similarity(doc[-2]), float)
|
||||
assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
|
||||
assert isinstance(doc[:2].similarity(doc[0]), float)
|
||||
assert isinstance(doc[0].similarity(doc[:2]), float)
|
||||
assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
|
||||
|
||||
|
||||
def test_vectors_similarity_DS(vocab, vectors):
|
||||
|
@ -91,3 +90,21 @@ def test_vectors_similarity_DS(vocab, vectors):
|
|||
doc = Doc(vocab, words=[word1, word2])
|
||||
assert isinstance(doc.similarity(doc[:2]), float)
|
||||
assert doc.similarity(doc[:2]) == doc[:2].similarity(doc)
|
||||
|
||||
|
||||
def test_vectors_similarity_no_vectors():
|
||||
vocab = Vocab()
|
||||
doc1 = Doc(vocab, words=["a", "b"])
|
||||
doc2 = Doc(vocab, words=["c", "d", "e"])
|
||||
with pytest.warns(UserWarning):
|
||||
doc1.similarity(doc2)
|
||||
with pytest.warns(UserWarning):
|
||||
doc1.similarity(doc2[1])
|
||||
with pytest.warns(UserWarning):
|
||||
doc1.similarity(doc2[:2])
|
||||
with pytest.warns(UserWarning):
|
||||
doc2.similarity(doc1)
|
||||
with pytest.warns(UserWarning):
|
||||
doc2[1].similarity(doc1)
|
||||
with pytest.warns(UserWarning):
|
||||
doc2[:2].similarity(doc1)
|
||||
|
|
|
@ -318,17 +318,15 @@ def test_vectors_lexeme_doc_similarity(vocab, text):
|
|||
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
|
||||
def test_vectors_span_span_similarity(vocab, text):
|
||||
doc = Doc(vocab, words=text)
|
||||
with pytest.warns(UserWarning):
|
||||
assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
|
||||
assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0
|
||||
assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
|
||||
assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
|
||||
def test_vectors_span_doc_similarity(vocab, text):
|
||||
doc = Doc(vocab, words=text)
|
||||
with pytest.warns(UserWarning):
|
||||
assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
|
||||
assert -1.0 < doc[0:2].similarity(doc) < 1.0
|
||||
assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
|
||||
assert -1.0 < doc[0:2].similarity(doc) < 1.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -607,7 +607,8 @@ cdef class Doc:
|
|||
if self.vocab.vectors.n_keys == 0:
|
||||
warnings.warn(Warnings.W007.format(obj="Doc"))
|
||||
if self.vector_norm == 0 or other.vector_norm == 0:
|
||||
warnings.warn(Warnings.W008.format(obj="Doc"))
|
||||
if not self.has_vector or not other.has_vector:
|
||||
warnings.warn(Warnings.W008.format(obj="Doc"))
|
||||
return 0.0
|
||||
vector = self.vector
|
||||
xp = get_array_module(vector)
|
||||
|
@ -627,7 +628,7 @@ cdef class Doc:
|
|||
if "has_vector" in self.user_hooks:
|
||||
return self.user_hooks["has_vector"](self)
|
||||
elif self.vocab.vectors.size:
|
||||
return True
|
||||
return any(token.has_vector for token in self)
|
||||
elif self.tensor.size:
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -354,7 +354,8 @@ cdef class Span:
|
|||
if self.vocab.vectors.n_keys == 0:
|
||||
warnings.warn(Warnings.W007.format(obj="Span"))
|
||||
if self.vector_norm == 0.0 or other.vector_norm == 0.0:
|
||||
warnings.warn(Warnings.W008.format(obj="Span"))
|
||||
if not self.has_vector or not other.has_vector:
|
||||
warnings.warn(Warnings.W008.format(obj="Span"))
|
||||
return 0.0
|
||||
vector = self.vector
|
||||
xp = get_array_module(vector)
|
||||
|
|
|
@ -206,7 +206,8 @@ cdef class Token:
|
|||
if self.vocab.vectors.n_keys == 0:
|
||||
warnings.warn(Warnings.W007.format(obj="Token"))
|
||||
if self.vector_norm == 0 or other.vector_norm == 0:
|
||||
warnings.warn(Warnings.W008.format(obj="Token"))
|
||||
if not self.has_vector or not other.has_vector:
|
||||
warnings.warn(Warnings.W008.format(obj="Token"))
|
||||
return 0.0
|
||||
vector = self.vector
|
||||
xp = get_array_module(vector)
|
||||
|
|
Loading…
Reference in New Issue
Block a user