💫 Add .similarity warnings for no vectors and option to exclude warnings (#2197)

* Add logic to filter out warning IDs via environment variable

Usage: SPACY_WARNING_EXCLUDE=W001,W007

* Add warnings for empty vectors

* Add warning if no word vectors are used in .similarity methods

For example, if only tensors are available in small models – should hopefully clear up some confusion around this

* Capture warnings in tests

* Rename SPACY_WARNING_EXCLUDE to SPACY_WARNING_IGNORE
This commit is contained in:
Ines Montani 2018-05-21 01:22:38 +02:00 committed by Matthew Honnibal
parent b096b22c20
commit cae4457c38
9 changed files with 52 additions and 19 deletions

View File

@ -38,6 +38,14 @@ class Warnings(object):
"surprising to you, make sure the Doc was processed using a model "
"that supports named entity recognition, and check the `doc.ents` "
"property manually if necessary.")
W007 = ("The model you're using has no word vectors loaded, so the result "
"of the {obj}.similarity method will be based on the tagger, "
"parser and NER, which may not give useful similarity judgements. "
"This may happen if you're using one of the small models, e.g. "
"`en_core_web_sm`, which don't ship with word vectors and only "
"use context-sensitive tensors. You can always add your own word "
"vectors, or use one of the larger models instead if available.")
W008 = ("Evaluating {obj}.similarity based on empty vectors.")
@add_codes
@ -286,8 +294,15 @@ def _get_warn_types(arg):
if w_type.strip() in WARNINGS]
def _get_warn_excl(arg):
if not arg:
return []
return [w_id.strip() for w_id in arg.split(',')]
SPACY_WARNING_FILTER = os.environ.get('SPACY_WARNING_FILTER', 'always')
SPACY_WARNING_TYPES = _get_warn_types(os.environ.get('SPACY_WARNING_TYPES'))
SPACY_WARNING_IGNORE = _get_warn_excl(os.environ.get('SPACY_WARNING_IGNORE'))
def user_warning(message):
@ -307,7 +322,8 @@ def _warn(message, warn_type='user'):
message (unicode): The message to display.
category (Warning): The Warning to show.
"""
if warn_type in SPACY_WARNING_TYPES:
w_id = message.split('[', 1)[1].split(']', 1)[0] # get ID from string
if warn_type in SPACY_WARNING_TYPES and w_id not in SPACY_WARNING_IGNORE:
category = WARNINGS[warn_type]
stack = inspect.stack()[-1]
with warnings.catch_warnings():

View File

@ -15,7 +15,7 @@ from .attrs cimport IS_TITLE, IS_UPPER, LIKE_URL, LIKE_NUM, LIKE_EMAIL, IS_STOP
from .attrs cimport IS_BRACKET, IS_QUOTE, IS_LEFT_PUNCT, IS_RIGHT_PUNCT, IS_CURRENCY, IS_OOV
from .attrs cimport PROB
from .attrs import intify_attrs
from .errors import Errors
from .errors import Errors, Warnings, user_warning
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
@ -122,6 +122,7 @@ cdef class Lexeme:
if self.c.orth == other[0].orth:
return 1.0
if self.vector_norm == 0 or other.vector_norm == 0:
user_warning(Warnings.W008.format(obj='Lexeme'))
return 0.0
return (numpy.dot(self.vector, other.vector) /
(self.vector_norm * other.vector_norm))

View File

@ -253,9 +253,11 @@ def test_doc_api_has_vector():
def test_doc_api_similarity_match():
doc = Doc(Vocab(), words=['a'])
with pytest.warns(None):
assert doc.similarity(doc[0]) == 1.0
assert doc.similarity(doc.vocab['a']) == 1.0
doc2 = Doc(doc.vocab, words=['a', 'b', 'c'])
with pytest.warns(None):
assert doc.similarity(doc2[:1]) == 1.0
assert doc.similarity(doc2) == 0.0

View File

@ -88,6 +88,7 @@ def test_span_similarity_match():
doc = Doc(Vocab(), words=['a', 'b', 'a', 'b'])
span1 = doc[:2]
span2 = doc[2:]
with pytest.warns(None):
assert span1.similarity(span2) == 1.0
assert span1.similarity(doc) == 0.0
assert span1[:1].similarity(doc.vocab['a']) == 1.0

View File

@ -45,6 +45,7 @@ def test_vectors_similarity_TT(vocab, vectors):
def test_vectors_similarity_TD(vocab, vectors):
[(word1, vec1), (word2, vec2)] = vectors
doc = get_doc(vocab, words=[word1, word2])
with pytest.warns(None):
assert doc.similarity(doc[0]) == doc[0].similarity(doc)
@ -57,4 +58,5 @@ def test_vectors_similarity_DS(vocab, vectors):
def test_vectors_similarity_TS(vocab, vectors):
[(word1, vec1), (word2, vec2)] = vectors
doc = get_doc(vocab, words=[word1, word2])
with pytest.warns(None):
assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])

View File

@ -206,6 +206,7 @@ def test_vectors_lexeme_doc_similarity(vocab, text):
@pytest.mark.parametrize('text', [["apple", "orange", "juice"]])
def test_vectors_span_span_similarity(vocab, text):
doc = get_doc(vocab, text)
with pytest.warns(None):
assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
assert -1. < doc[0:2].similarity(doc[1:3]) < 1.0
@ -213,6 +214,7 @@ def test_vectors_span_span_similarity(vocab, text):
@pytest.mark.parametrize('text', [["apple", "orange", "juice"]])
def test_vectors_span_doc_similarity(vocab, text):
doc = get_doc(vocab, text)
with pytest.warns(None):
assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
assert -1. < doc[0:2].similarity(doc) < 1.0

View File

@ -31,7 +31,8 @@ from ..attrs cimport ENT_TYPE, SENT_START
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
from ..util import normalize_slice
from ..compat import is_config, copy_reg, pickle, basestring_
from ..errors import Errors, Warnings, deprecation_warning
from ..errors import deprecation_warning, models_warning, user_warning
from ..errors import Errors, Warnings
from .. import util
from .underscore import Underscore, get_ext_args
from ._retokenize import Retokenizer
@ -318,8 +319,10 @@ cdef class Doc:
break
else:
return 1.0
if self.vocab.vectors.n_keys == 0:
models_warning(Warnings.W007.format(obj='Doc'))
if self.vector_norm == 0 or other.vector_norm == 0:
user_warning(Warnings.W008.format(obj='Doc'))
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)

View File

@ -16,7 +16,7 @@ from ..util import normalize_slice
from ..attrs cimport IS_PUNCT, IS_SPACE
from ..lexeme cimport Lexeme
from ..compat import is_config
from ..errors import Errors, TempErrors
from ..errors import Errors, TempErrors, Warnings, user_warning, models_warning
from .underscore import Underscore, get_ext_args
@ -200,7 +200,10 @@ cdef class Span:
break
else:
return 1.0
if self.vocab.vectors.n_keys == 0:
models_warning(Warnings.W007.format(obj='Span'))
if self.vector_norm == 0.0 or other.vector_norm == 0.0:
user_warning(Warnings.W008.format(obj='Span'))
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)

View File

@ -19,7 +19,7 @@ from ..attrs cimport IS_OOV, IS_TITLE, IS_UPPER, IS_CURRENCY, LIKE_URL, LIKE_NUM
from ..attrs cimport IS_STOP, ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX
from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP
from ..compat import is_config
from ..errors import Errors
from ..errors import Errors, Warnings, user_warning, models_warning
from .. import util
from .underscore import Underscore, get_ext_args
@ -161,7 +161,10 @@ cdef class Token:
elif hasattr(other, 'orth'):
if self.c.lex.orth == other.orth:
return 1.0
if self.vocab.vectors.n_keys == 0:
models_warning(Warnings.W007.format(obj='Token'))
if self.vector_norm == 0 or other.vector_norm == 0:
user_warning(Warnings.W008.format(obj='Token'))
return 0.0
return (numpy.dot(self.vector, other.vector) /
(self.vector_norm * other.vector_norm))