mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Update vector handling in similarity methods (#11013)
Distinguish between zero vectors and missing vectors when warning about missing vectors. Update `Doc.has_vector` to match `Span.has_vector` and `Token.has_vector` in cases where the vocab has vectors but none of the tokens in the container have vectors.
This commit is contained in:
parent
1d5cad0b42
commit
24f4908fce
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import numpy
|
import numpy
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
from spacy.vocab import Vocab
|
||||||
|
|
||||||
from ..util import get_cosine, add_vecs_to_vocab
|
from ..util import get_cosine, add_vecs_to_vocab
|
||||||
|
|
||||||
|
@ -71,7 +72,6 @@ def test_vectors_similarity_DD(vocab, vectors):
|
||||||
def test_vectors_similarity_TD(vocab, vectors):
|
def test_vectors_similarity_TD(vocab, vectors):
|
||||||
[(word1, vec1), (word2, vec2)] = vectors
|
[(word1, vec1), (word2, vec2)] = vectors
|
||||||
doc = Doc(vocab, words=[word1, word2])
|
doc = Doc(vocab, words=[word1, word2])
|
||||||
with pytest.warns(UserWarning):
|
|
||||||
assert isinstance(doc.similarity(doc[0]), float)
|
assert isinstance(doc.similarity(doc[0]), float)
|
||||||
assert isinstance(doc[0].similarity(doc), float)
|
assert isinstance(doc[0].similarity(doc), float)
|
||||||
assert doc.similarity(doc[0]) == doc[0].similarity(doc)
|
assert doc.similarity(doc[0]) == doc[0].similarity(doc)
|
||||||
|
@ -80,9 +80,8 @@ def test_vectors_similarity_TD(vocab, vectors):
|
||||||
def test_vectors_similarity_TS(vocab, vectors):
|
def test_vectors_similarity_TS(vocab, vectors):
|
||||||
[(word1, vec1), (word2, vec2)] = vectors
|
[(word1, vec1), (word2, vec2)] = vectors
|
||||||
doc = Doc(vocab, words=[word1, word2])
|
doc = Doc(vocab, words=[word1, word2])
|
||||||
with pytest.warns(UserWarning):
|
|
||||||
assert isinstance(doc[:2].similarity(doc[0]), float)
|
assert isinstance(doc[:2].similarity(doc[0]), float)
|
||||||
assert isinstance(doc[0].similarity(doc[-2]), float)
|
assert isinstance(doc[0].similarity(doc[:2]), float)
|
||||||
assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
|
assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,3 +90,21 @@ def test_vectors_similarity_DS(vocab, vectors):
|
||||||
doc = Doc(vocab, words=[word1, word2])
|
doc = Doc(vocab, words=[word1, word2])
|
||||||
assert isinstance(doc.similarity(doc[:2]), float)
|
assert isinstance(doc.similarity(doc[:2]), float)
|
||||||
assert doc.similarity(doc[:2]) == doc[:2].similarity(doc)
|
assert doc.similarity(doc[:2]) == doc[:2].similarity(doc)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectors_similarity_no_vectors():
|
||||||
|
vocab = Vocab()
|
||||||
|
doc1 = Doc(vocab, words=["a", "b"])
|
||||||
|
doc2 = Doc(vocab, words=["c", "d", "e"])
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
doc1.similarity(doc2)
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
doc1.similarity(doc2[1])
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
doc1.similarity(doc2[:2])
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
doc2.similarity(doc1)
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
doc2[1].similarity(doc1)
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
doc2[:2].similarity(doc1)
|
||||||
|
|
|
@ -318,7 +318,6 @@ def test_vectors_lexeme_doc_similarity(vocab, text):
|
||||||
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
|
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
|
||||||
def test_vectors_span_span_similarity(vocab, text):
|
def test_vectors_span_span_similarity(vocab, text):
|
||||||
doc = Doc(vocab, words=text)
|
doc = Doc(vocab, words=text)
|
||||||
with pytest.warns(UserWarning):
|
|
||||||
assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
|
assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
|
||||||
assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0
|
assert -1.0 < doc[0:2].similarity(doc[1:3]) < 1.0
|
||||||
|
|
||||||
|
@ -326,7 +325,6 @@ def test_vectors_span_span_similarity(vocab, text):
|
||||||
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
|
@pytest.mark.parametrize("text", [["apple", "orange", "juice"]])
|
||||||
def test_vectors_span_doc_similarity(vocab, text):
|
def test_vectors_span_doc_similarity(vocab, text):
|
||||||
doc = Doc(vocab, words=text)
|
doc = Doc(vocab, words=text)
|
||||||
with pytest.warns(UserWarning):
|
|
||||||
assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
|
assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
|
||||||
assert -1.0 < doc[0:2].similarity(doc) < 1.0
|
assert -1.0 < doc[0:2].similarity(doc) < 1.0
|
||||||
|
|
||||||
|
|
|
@ -607,6 +607,7 @@ cdef class Doc:
|
||||||
if self.vocab.vectors.n_keys == 0:
|
if self.vocab.vectors.n_keys == 0:
|
||||||
warnings.warn(Warnings.W007.format(obj="Doc"))
|
warnings.warn(Warnings.W007.format(obj="Doc"))
|
||||||
if self.vector_norm == 0 or other.vector_norm == 0:
|
if self.vector_norm == 0 or other.vector_norm == 0:
|
||||||
|
if not self.has_vector or not other.has_vector:
|
||||||
warnings.warn(Warnings.W008.format(obj="Doc"))
|
warnings.warn(Warnings.W008.format(obj="Doc"))
|
||||||
return 0.0
|
return 0.0
|
||||||
vector = self.vector
|
vector = self.vector
|
||||||
|
@ -627,7 +628,7 @@ cdef class Doc:
|
||||||
if "has_vector" in self.user_hooks:
|
if "has_vector" in self.user_hooks:
|
||||||
return self.user_hooks["has_vector"](self)
|
return self.user_hooks["has_vector"](self)
|
||||||
elif self.vocab.vectors.size:
|
elif self.vocab.vectors.size:
|
||||||
return True
|
return any(token.has_vector for token in self)
|
||||||
elif self.tensor.size:
|
elif self.tensor.size:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -354,6 +354,7 @@ cdef class Span:
|
||||||
if self.vocab.vectors.n_keys == 0:
|
if self.vocab.vectors.n_keys == 0:
|
||||||
warnings.warn(Warnings.W007.format(obj="Span"))
|
warnings.warn(Warnings.W007.format(obj="Span"))
|
||||||
if self.vector_norm == 0.0 or other.vector_norm == 0.0:
|
if self.vector_norm == 0.0 or other.vector_norm == 0.0:
|
||||||
|
if not self.has_vector or not other.has_vector:
|
||||||
warnings.warn(Warnings.W008.format(obj="Span"))
|
warnings.warn(Warnings.W008.format(obj="Span"))
|
||||||
return 0.0
|
return 0.0
|
||||||
vector = self.vector
|
vector = self.vector
|
||||||
|
|
|
@ -206,6 +206,7 @@ cdef class Token:
|
||||||
if self.vocab.vectors.n_keys == 0:
|
if self.vocab.vectors.n_keys == 0:
|
||||||
warnings.warn(Warnings.W007.format(obj="Token"))
|
warnings.warn(Warnings.W007.format(obj="Token"))
|
||||||
if self.vector_norm == 0 or other.vector_norm == 0:
|
if self.vector_norm == 0 or other.vector_norm == 0:
|
||||||
|
if not self.has_vector or not other.has_vector:
|
||||||
warnings.warn(Warnings.W008.format(obj="Token"))
|
warnings.warn(Warnings.W008.format(obj="Token"))
|
||||||
return 0.0
|
return 0.0
|
||||||
vector = self.vector
|
vector = self.vector
|
||||||
|
|
Loading…
Reference in New Issue
Block a user