Mirror of https://github.com/explosion/spaCy.git, synced 2024-12-24 17:06:29 +03:00
💫 Add .similarity warnings for no vectors and option to exclude warnings (#2197)
* Add logic to filter out warning IDs via environment variable
  Usage: SPACY_WARNING_EXCLUDE=W001,W007
* Add warnings for empty vectors
* Add warning if no word vectors are used in .similarity methods
  For example, if only tensors are available in small models – should hopefully clear up some confusion around this
* Capture warnings in tests
* Rename SPACY_WARNING_EXCLUDE to SPACY_WARNING_IGNORE
parent b096b22c20
commit cae4457c38
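The new filtering is configured entirely through environment variables that spacy/errors.py reads once at import time (see the first hunk below). A minimal sketch of how a user would silence specific warnings; the IDs chosen here are just for illustration:

```python
import os

# Must be set before spaCy is first imported, because spacy/errors.py
# reads SPACY_WARNING_IGNORE only once, at module load.
os.environ['SPACY_WARNING_IGNORE'] = 'W007,W008'

import spacy  # warnings tagged [W007]/[W008] are now suppressed
```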
spacy/errors.py

```diff
@@ -38,6 +38,14 @@ class Warnings(object):
             "surprising to you, make sure the Doc was processed using a model "
             "that supports named entity recognition, and check the `doc.ents` "
             "property manually if necessary.")
+    W007 = ("The model you're using has no word vectors loaded, so the result "
+            "of the {obj}.similarity method will be based on the tagger, "
+            "parser and NER, which may not give useful similarity judgements. "
+            "This may happen if you're using one of the small models, e.g. "
+            "`en_core_web_sm`, which don't ship with word vectors and only "
+            "use context-sensitive tensors. You can always add your own word "
+            "vectors, or use one of the larger models instead if available.")
+    W008 = ("Evaluating {obj}.similarity based on empty vectors.")
 
 
 @add_codes
```
```diff
@@ -286,8 +294,15 @@ def _get_warn_types(arg):
             if w_type.strip() in WARNINGS]
 
 
+def _get_warn_excl(arg):
+    if not arg:
+        return []
+    return [w_id.strip() for w_id in arg.split(',')]
+
+
 SPACY_WARNING_FILTER = os.environ.get('SPACY_WARNING_FILTER', 'always')
 SPACY_WARNING_TYPES = _get_warn_types(os.environ.get('SPACY_WARNING_TYPES'))
+SPACY_WARNING_IGNORE = _get_warn_excl(os.environ.get('SPACY_WARNING_IGNORE'))
 
 
 def user_warning(message):
```
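A quick check of how the new helper behaves; the function body is copied verbatim from the hunk above, the asserts are illustrative:

```python
def _get_warn_excl(arg):
    if not arg:
        return []
    return [w_id.strip() for w_id in arg.split(',')]

# Unset or empty variable -> ignore nothing; whitespace around IDs is tolerated.
assert _get_warn_excl(None) == []
assert _get_warn_excl('') == []
assert _get_warn_excl('W001, W007') == ['W001', 'W007']
```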
```diff
@@ -307,7 +322,8 @@ def _warn(message, warn_type='user'):
     message (unicode): The message to display.
     category (Warning): The Warning to show.
     """
-    if warn_type in SPACY_WARNING_TYPES:
+    w_id = message.split('[', 1)[1].split(']', 1)[0]  # get ID from string
+    if warn_type in SPACY_WARNING_TYPES and w_id not in SPACY_WARNING_IGNORE:
         category = WARNINGS[warn_type]
         stack = inspect.stack()[-1]
         with warnings.catch_warnings():
```
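The ignore check works because every spaCy warning message carries its code in square brackets (the `@add_codes` decorator prepends it). A small illustration of the ID-extraction expression, using the new W008 text:

```python
# Message format produced by @add_codes: '[<code>] <text>'
message = "[W008] Evaluating Doc.similarity based on empty vectors."

# Same expression as in _warn above: text between the first '[' and ']'.
w_id = message.split('[', 1)[1].split(']', 1)[0]
assert w_id == 'W008'
```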
spacy/lexeme.pyx

```diff
@@ -15,7 +15,7 @@ from .attrs cimport IS_TITLE, IS_UPPER, LIKE_URL, LIKE_NUM, LIKE_EMAIL, IS_STOP
 from .attrs cimport IS_BRACKET, IS_QUOTE, IS_LEFT_PUNCT, IS_RIGHT_PUNCT, IS_CURRENCY, IS_OOV
 from .attrs cimport PROB
 from .attrs import intify_attrs
-from .errors import Errors
+from .errors import Errors, Warnings, user_warning
 
 
 memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
```
```diff
@@ -122,6 +122,7 @@ cdef class Lexeme:
         if self.c.orth == other[0].orth:
             return 1.0
         if self.vector_norm == 0 or other.vector_norm == 0:
+            user_warning(Warnings.W008.format(obj='Lexeme'))
             return 0.0
         return (numpy.dot(self.vector, other.vector) /
                 (self.vector_norm * other.vector_norm))
```
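For context, `.similarity` is plain cosine similarity, and W008 fires exactly where the formula would otherwise divide by zero. A standalone sketch with numpy (the function name is ours, not spaCy's):

```python
import numpy

def cosine_similarity(vec1, vec2):
    norm1 = numpy.linalg.norm(vec1)
    norm2 = numpy.linalg.norm(vec2)
    if norm1 == 0 or norm2 == 0:
        # The case W008 now warns about: spaCy returns 0.0 here.
        return 0.0
    return numpy.dot(vec1, vec2) / (norm1 * norm2)

print(cosine_similarity(numpy.array([1., 0.]), numpy.array([1., 1.])))  # ~0.7071
print(cosine_similarity(numpy.array([0., 0.]), numpy.array([1., 1.])))  # 0.0
```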
spacy/tests/doc/test_doc_api.py

```diff
@@ -253,9 +253,11 @@ def test_doc_api_has_vector():
 
 def test_doc_api_similarity_match():
     doc = Doc(Vocab(), words=['a'])
-    assert doc.similarity(doc[0]) == 1.0
-    assert doc.similarity(doc.vocab['a']) == 1.0
+    with pytest.warns(None):
+        assert doc.similarity(doc[0]) == 1.0
+        assert doc.similarity(doc.vocab['a']) == 1.0
     doc2 = Doc(doc.vocab, words=['a', 'b', 'c'])
-    assert doc.similarity(doc2[:1]) == 1.0
-    assert doc.similarity(doc2) == 0.0
+    with pytest.warns(None):
+        assert doc.similarity(doc2[:1]) == 1.0
+        assert doc.similarity(doc2) == 0.0
 
```
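`pytest.warns(None)` (valid in the pytest versions of this era; later pytest deprecated the `None` argument) records every warning raised in the block without failing on any particular one, so the similarity assertions still run when W007/W008 fire. A self-contained sketch of the pattern:

```python
import warnings
import pytest

def test_warning_is_captured():
    with pytest.warns(None) as record:
        warnings.warn('[W008] Evaluating Doc.similarity based on empty vectors.')
    # The warning was recorded, not raised as a test failure.
    assert len(record) == 1
    assert '[W008]' in str(record[0].message)
```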
spacy/tests/doc/test_span.py

```diff
@@ -88,6 +88,7 @@ def test_span_similarity_match():
     doc = Doc(Vocab(), words=['a', 'b', 'a', 'b'])
     span1 = doc[:2]
     span2 = doc[2:]
-    assert span1.similarity(span2) == 1.0
-    assert span1.similarity(doc) == 0.0
-    assert span1[:1].similarity(doc.vocab['a']) == 1.0
+    with pytest.warns(None):
+        assert span1.similarity(span2) == 1.0
+        assert span1.similarity(doc) == 0.0
+        assert span1[:1].similarity(doc.vocab['a']) == 1.0
```
spacy/tests/vectors/test_similarity.py

```diff
@@ -45,6 +45,7 @@ def test_vectors_similarity_TT(vocab, vectors):
 
 def test_vectors_similarity_TD(vocab, vectors):
     [(word1, vec1), (word2, vec2)] = vectors
     doc = get_doc(vocab, words=[word1, word2])
-    assert doc.similarity(doc[0]) == doc[0].similarity(doc)
+    with pytest.warns(None):
+        assert doc.similarity(doc[0]) == doc[0].similarity(doc)
 
```
```diff
@@ -57,4 +58,5 @@ def test_vectors_similarity_DS(vocab, vectors):
 def test_vectors_similarity_TS(vocab, vectors):
     [(word1, vec1), (word2, vec2)] = vectors
     doc = get_doc(vocab, words=[word1, word2])
-    assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
+    with pytest.warns(None):
+        assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2])
```
spacy/tests/vectors/test_vectors.py

```diff
@@ -206,6 +206,7 @@ def test_vectors_lexeme_doc_similarity(vocab, text):
 @pytest.mark.parametrize('text', [["apple", "orange", "juice"]])
 def test_vectors_span_span_similarity(vocab, text):
     doc = get_doc(vocab, text)
-    assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
-    assert -1. < doc[0:2].similarity(doc[1:3]) < 1.0
+    with pytest.warns(None):
+        assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2])
+        assert -1. < doc[0:2].similarity(doc[1:3]) < 1.0
 
```
```diff
@@ -213,6 +214,7 @@ def test_vectors_span_span_similarity(vocab, text):
 @pytest.mark.parametrize('text', [["apple", "orange", "juice"]])
 def test_vectors_span_doc_similarity(vocab, text):
     doc = get_doc(vocab, text)
-    assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
-    assert -1. < doc[0:2].similarity(doc) < 1.0
+    with pytest.warns(None):
+        assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2])
+        assert -1. < doc[0:2].similarity(doc) < 1.0
 
```
spacy/tokens/doc.pyx

```diff
@@ -31,7 +31,8 @@ from ..attrs cimport ENT_TYPE, SENT_START
 from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
 from ..util import normalize_slice
 from ..compat import is_config, copy_reg, pickle, basestring_
-from ..errors import Errors, Warnings, deprecation_warning
+from ..errors import deprecation_warning, models_warning, user_warning
+from ..errors import Errors, Warnings
 from .. import util
 from .underscore import Underscore, get_ext_args
 from ._retokenize import Retokenizer
```
```diff
@@ -318,8 +319,10 @@ cdef class Doc:
                     break
             else:
                 return 1.0
-
+        if self.vocab.vectors.n_keys == 0:
+            models_warning(Warnings.W007.format(obj='Doc'))
         if self.vector_norm == 0 or other.vector_norm == 0:
+            user_warning(Warnings.W008.format(obj='Doc'))
             return 0.0
         return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
 
```
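From the user's side, assuming a small model such as `en_core_web_sm` is installed (a hypothetical session; the warning text is W007 from this commit):

```python
import spacy

nlp = spacy.load('en_core_web_sm')  # small model: ships without word vectors
doc1 = nlp(u'cat')
doc2 = nlp(u'dog')

# Still returns a score (computed from context-sensitive tensors), but now
# also emits: [W007] The model you're using has no word vectors loaded, so
# the result of the Doc.similarity method will be based on the tagger,
# parser and NER ...
print(doc1.similarity(doc2))
```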
spacy/tokens/span.pyx

```diff
@@ -16,7 +16,7 @@ from ..util import normalize_slice
 from ..attrs cimport IS_PUNCT, IS_SPACE
 from ..lexeme cimport Lexeme
 from ..compat import is_config
-from ..errors import Errors, TempErrors
+from ..errors import Errors, TempErrors, Warnings, user_warning, models_warning
 from .underscore import Underscore, get_ext_args
 
 
```
```diff
@@ -200,7 +200,10 @@ cdef class Span:
                     break
             else:
                 return 1.0
+        if self.vocab.vectors.n_keys == 0:
+            models_warning(Warnings.W007.format(obj='Span'))
         if self.vector_norm == 0.0 or other.vector_norm == 0.0:
+            user_warning(Warnings.W008.format(obj='Span'))
             return 0.0
         return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
 
```
spacy/tokens/token.pyx

```diff
@@ -19,7 +19,7 @@ from ..attrs cimport IS_OOV, IS_TITLE, IS_UPPER, IS_CURRENCY, LIKE_URL, LIKE_NUM
 from ..attrs cimport IS_STOP, ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX
 from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP
 from ..compat import is_config
-from ..errors import Errors
+from ..errors import Errors, Warnings, user_warning, models_warning
 from .. import util
 from .underscore import Underscore, get_ext_args
 
```
```diff
@@ -161,7 +161,10 @@ cdef class Token:
         elif hasattr(other, 'orth'):
             if self.c.lex.orth == other.orth:
                 return 1.0
+        if self.vocab.vectors.n_keys == 0:
+            models_warning(Warnings.W007.format(obj='Token'))
         if self.vector_norm == 0 or other.vector_norm == 0:
+            user_warning(Warnings.W008.format(obj='Token'))
             return 0.0
         return (numpy.dot(self.vector, other.vector) /
                 (self.vector_norm * other.vector_norm))
```