diff --git a/spacy/errors.py b/spacy/errors.py index a557be2e8..b812a6f76 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -38,6 +38,14 @@ class Warnings(object): "surprising to you, make sure the Doc was processed using a model " "that supports named entity recognition, and check the `doc.ents` " "property manually if necessary.") + W007 = ("The model you're using has no word vectors loaded, so the result " + "of the {obj}.similarity method will be based on the tagger, " + "parser and NER, which may not give useful similarity judgements. " + "This may happen if you're using one of the small models, e.g. " + "`en_core_web_sm`, which don't ship with word vectors and only " + "use context-sensitive tensors. You can always add your own word " + "vectors, or use one of the larger models instead if available.") + W008 = ("Evaluating {obj}.similarity based on empty vectors.") @add_codes @@ -286,8 +294,15 @@ def _get_warn_types(arg): if w_type.strip() in WARNINGS] +def _get_warn_excl(arg): + if not arg: + return [] + return [w_id.strip() for w_id in arg.split(',')] + + SPACY_WARNING_FILTER = os.environ.get('SPACY_WARNING_FILTER', 'always') SPACY_WARNING_TYPES = _get_warn_types(os.environ.get('SPACY_WARNING_TYPES')) +SPACY_WARNING_IGNORE = _get_warn_excl(os.environ.get('SPACY_WARNING_IGNORE')) def user_warning(message): @@ -307,7 +322,8 @@ def _warn(message, warn_type='user'): message (unicode): The message to display. category (Warning): The Warning to show. """ - if warn_type in SPACY_WARNING_TYPES: + w_id = message.split('[', 1)[1].split(']', 1)[0] # get ID from string + if warn_type in SPACY_WARNING_TYPES and w_id not in SPACY_WARNING_IGNORE: category = WARNINGS[warn_type] stack = inspect.stack()[-1] with warnings.catch_warnings(): diff --git a/spacy/lexeme.pyx b/spacy/lexeme.pyx index ca93df9bd..e85f1183b 100644 --- a/spacy/lexeme.pyx +++ b/spacy/lexeme.pyx @@ -15,7 +15,7 @@ from .attrs cimport IS_TITLE, IS_UPPER, LIKE_URL, LIKE_NUM, LIKE_EMAIL, IS_STOP from .attrs cimport IS_BRACKET, IS_QUOTE, IS_LEFT_PUNCT, IS_RIGHT_PUNCT, IS_CURRENCY, IS_OOV from .attrs cimport PROB from .attrs import intify_attrs -from .errors import Errors +from .errors import Errors, Warnings, user_warning memset(&EMPTY_LEXEME, 0, sizeof(LexemeC)) @@ -122,6 +122,7 @@ cdef class Lexeme: if self.c.orth == other[0].orth: return 1.0 if self.vector_norm == 0 or other.vector_norm == 0: + user_warning(Warnings.W008.format(obj='Lexeme')) return 0.0 return (numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)) diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index d9db0916b..10f99223b 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -253,11 +253,13 @@ def test_doc_api_has_vector(): def test_doc_api_similarity_match(): doc = Doc(Vocab(), words=['a']) - assert doc.similarity(doc[0]) == 1.0 - assert doc.similarity(doc.vocab['a']) == 1.0 + with pytest.warns(None): + assert doc.similarity(doc[0]) == 1.0 + assert doc.similarity(doc.vocab['a']) == 1.0 doc2 = Doc(doc.vocab, words=['a', 'b', 'c']) - assert doc.similarity(doc2[:1]) == 1.0 - assert doc.similarity(doc2) == 0.0 + with pytest.warns(None): + assert doc.similarity(doc2[:1]) == 1.0 + assert doc.similarity(doc2) == 0.0 def test_lowest_common_ancestor(en_tokenizer): diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index 4cbb8ed94..d355f06c5 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -88,9 +88,10 @@ def test_span_similarity_match(): doc = Doc(Vocab(), words=['a', 'b', 'a', 'b']) span1 = doc[:2] span2 = doc[2:] - assert span1.similarity(span2) == 1.0 - assert span1.similarity(doc) == 0.0 - assert span1[:1].similarity(doc.vocab['a']) == 1.0 + with pytest.warns(None): + assert span1.similarity(span2) == 1.0 + assert span1.similarity(doc) == 0.0 + assert span1[:1].similarity(doc.vocab['a']) == 1.0 def test_spans_default_sentiment(en_tokenizer): diff --git a/spacy/tests/vectors/test_similarity.py b/spacy/tests/vectors/test_similarity.py index f9c18adca..231e641de 100644 --- a/spacy/tests/vectors/test_similarity.py +++ b/spacy/tests/vectors/test_similarity.py @@ -45,7 +45,8 @@ def test_vectors_similarity_TT(vocab, vectors): def test_vectors_similarity_TD(vocab, vectors): [(word1, vec1), (word2, vec2)] = vectors doc = get_doc(vocab, words=[word1, word2]) - assert doc.similarity(doc[0]) == doc[0].similarity(doc) + with pytest.warns(None): + assert doc.similarity(doc[0]) == doc[0].similarity(doc) def test_vectors_similarity_DS(vocab, vectors): @@ -57,4 +58,5 @@ def test_vectors_similarity_DS(vocab, vectors): def test_vectors_similarity_TS(vocab, vectors): [(word1, vec1), (word2, vec2)] = vectors doc = get_doc(vocab, words=[word1, word2]) - assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2]) + with pytest.warns(None): + assert doc[:2].similarity(doc[0]) == doc[0].similarity(doc[:2]) diff --git a/spacy/tests/vectors/test_vectors.py b/spacy/tests/vectors/test_vectors.py index c72777c07..831fbf003 100644 --- a/spacy/tests/vectors/test_vectors.py +++ b/spacy/tests/vectors/test_vectors.py @@ -206,15 +206,17 @@ def test_vectors_lexeme_doc_similarity(vocab, text): @pytest.mark.parametrize('text', [["apple", "orange", "juice"]]) def test_vectors_span_span_similarity(vocab, text): doc = get_doc(vocab, text) - assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2]) - assert -1. < doc[0:2].similarity(doc[1:3]) < 1.0 + with pytest.warns(None): + assert doc[0:2].similarity(doc[1:3]) == doc[1:3].similarity(doc[0:2]) + assert -1. < doc[0:2].similarity(doc[1:3]) < 1.0 @pytest.mark.parametrize('text', [["apple", "orange", "juice"]]) def test_vectors_span_doc_similarity(vocab, text): doc = get_doc(vocab, text) - assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2]) - assert -1. < doc[0:2].similarity(doc) < 1.0 + with pytest.warns(None): + assert doc[0:2].similarity(doc) == doc.similarity(doc[0:2]) + assert -1. < doc[0:2].similarity(doc) < 1.0 @pytest.mark.parametrize('text1,text2', [ diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 6f8dcebde..e4d8fc269 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -31,7 +31,8 @@ from ..attrs cimport ENT_TYPE, SENT_START from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t from ..util import normalize_slice from ..compat import is_config, copy_reg, pickle, basestring_ -from ..errors import Errors, Warnings, deprecation_warning +from ..errors import deprecation_warning, models_warning, user_warning +from ..errors import Errors, Warnings from .. import util from .underscore import Underscore, get_ext_args from ._retokenize import Retokenizer @@ -318,8 +319,10 @@ cdef class Doc: break else: return 1.0 - + if self.vocab.vectors.n_keys == 0: + models_warning(Warnings.W007.format(obj='Doc')) if self.vector_norm == 0 or other.vector_norm == 0: + user_warning(Warnings.W008.format(obj='Doc')) return 0.0 return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm) diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 35ca8eaaf..b2a5a02cb 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -16,7 +16,7 @@ from ..util import normalize_slice from ..attrs cimport IS_PUNCT, IS_SPACE from ..lexeme cimport Lexeme from ..compat import is_config -from ..errors import Errors, TempErrors +from ..errors import Errors, TempErrors, Warnings, user_warning, models_warning from .underscore import Underscore, get_ext_args @@ -200,7 +200,10 @@ cdef class Span: break else: return 1.0 + if self.vocab.vectors.n_keys == 0: + models_warning(Warnings.W007.format(obj='Span')) if self.vector_norm == 0.0 or other.vector_norm == 0.0: + user_warning(Warnings.W008.format(obj='Span')) return 0.0 return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm) diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx index 0cc29392b..4b0a16d3d 100644 --- a/spacy/tokens/token.pyx +++ b/spacy/tokens/token.pyx @@ -19,7 +19,7 @@ from ..attrs cimport IS_OOV, IS_TITLE, IS_UPPER, IS_CURRENCY, LIKE_URL, LIKE_NUM from ..attrs cimport IS_STOP, ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX from ..attrs cimport LENGTH, CLUSTER, LEMMA, POS, TAG, DEP from ..compat import is_config -from ..errors import Errors +from ..errors import Errors, Warnings, user_warning, models_warning from .. import util from .underscore import Underscore, get_ext_args @@ -161,7 +161,10 @@ cdef class Token: elif hasattr(other, 'orth'): if self.c.lex.orth == other.orth: return 1.0 + if self.vocab.vectors.n_keys == 0: + models_warning(Warnings.W007.format(obj='Token')) if self.vector_norm == 0 or other.vector_norm == 0: + user_warning(Warnings.W008.format(obj='Token')) return 0.0 return (numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm))