Use warnings.warn instead of logger.warning

Adriane Boyd 2021-06-04 17:44:04 +02:00
parent f0277bdeab
commit 9dfd3c9484
8 changed files with 16 additions and 22 deletions
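The practical difference behind the change: warnings.warn emits a UserWarning that the standard warnings machinery and pytest.warns can see, while logger.warning output is invisible to both and can only be managed through logging configuration. A minimal illustration (the message text here is shortened from the real W113):

    import warnings
    import pytest

    def emit():
        # spaCy prefixes warning messages with their code, e.g. "[W113]"
        warnings.warn("[W113] sourced component may not work as expected")

    # assert that the warning fires, matching on the code prefix
    with pytest.warns(UserWarning, match=r"\[W113\]"):
        emit()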

View File

@@ -24,6 +24,9 @@ def setup_default_warnings():
     for pipe in ["matcher", "entity_ruler"]:
         filter_warning("once", error_msg=Warnings.W036.format(name=pipe))
 
+    # warn once about lemmatizer without required POS
+    filter_warning("once", error_msg="[W108]")
+
 
 def filter_warning(action: str, error_msg: str):
     """Customize how spaCy should handle a certain warning.

View File

@@ -689,7 +689,7 @@ class Language:
         if self.vocab.vectors.shape != source.vocab.vectors.shape or \
                 self.vocab.vectors.key2row != source.vocab.vectors.key2row or \
                 self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes():
-            util.logger.warning(Warnings.W113.format(name=source_name))
+            warnings.warn(Warnings.W113.format(name=source_name))
         if not source_name in source.component_names:
             raise KeyError(
                 Errors.E944.format(
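A side effect worth noting: because W113 is now a real warning, a user who sources a component from a pipeline with different vectors on purpose can silence it with the standard filters instead of reconfiguring the "spacy" logger. A hedged sketch, where nlp and source_nlp stand for pipelines assumed to be built elsewhere:

    import warnings

    with warnings.catch_warnings():
        # nlp / source_nlp are placeholder pipelines, not defined in this commit
        warnings.filterwarnings("ignore", message=r"\[W113\]")
        nlp.add_pipe("tagger", source=source_nlp)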

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
 from itertools import product
 import numpy
+import warnings
 
 from .matcher cimport Matcher
 from ..vocab cimport Vocab
@@ -11,7 +12,6 @@ from ..tokens.doc cimport Doc
 from ..errors import Errors, Warnings
 from ..tokens import Span
-from ..util import logger
 
 DELIMITER = "||"
@@ -282,7 +282,7 @@ cdef class DependencyMatcher:
         keys_to_position_maps = defaultdict(lambda: defaultdict(list))
         for match_id, start, end in self._matcher(doc):
             if start + 1 != end:
-                logger.warning(Warnings.W110.format(tokens=[t.text for t in doc[start:end]], pattern=self._matcher.get(match_id)[1][0][0]))
+                warnings.warn(Warnings.W110.format(tokens=[t.text for t in doc[start:end]], pattern=self._matcher.get(match_id)[1][0][0]))
             token = doc[start]
             root = ([token] + list(token.ancestors))[-1]
             keys_to_position_maps[root.i][match_id].append(start)

View File

@@ -2,6 +2,8 @@ from typing import Optional, List, Dict, Any, Callable, Iterable, Union, Tuple
 from thinc.api import Model
 from pathlib import Path
+
+import warnings
 
 from .pipe import Pipe
 from ..errors import Errors, Warnings
 from ..language import Language
@@ -182,7 +184,7 @@ class Lemmatizer(Pipe):
         univ_pos = token.pos_.lower()
         if univ_pos in ("", "eol", "space"):
             if univ_pos == "":
-                logger.warning(Warnings.W108.format(text=string))
+                warnings.warn(Warnings.W108.format(text=string))
             return [string.lower()]
         # See Issue #435 for example of where this logic is requied.
         if self.is_base_form(token):
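This path runs per token, and W108 embeds the token text, so the "once" filter registered in setup_default_warnings dedupes per distinct message rather than globally. Users who run the rule lemmatizer without a tagger deliberately can drop the warning altogether; a small sketch:

    import warnings

    # silence every W108 for this process; spaCy's default filter only
    # collapses repeats of an identical message down to one
    warnings.filterwarnings("ignore", message=r"\[W108\]")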

View File

@@ -2,8 +2,6 @@ import weakref
 
 import pytest
 import numpy
-import logging
-import mock
 
 from spacy.lang.xx import MultiLanguage
 from spacy.tokens import Doc, Span, Token
@@ -158,13 +156,10 @@ def test_doc_api_serialize(en_tokenizer, text):
     def inner_func(d1, d2):
         return "hello!"
 
-    logger = logging.getLogger("spacy")
-    with mock.patch.object(logger, "warning") as mock_warning:
-        _ = tokens.to_bytes()  # noqa: F841
-        mock_warning.assert_not_called()
+    _ = tokens.to_bytes()  # noqa: F841
+    with pytest.warns(UserWarning):
         tokens.user_hooks["similarity"] = inner_func
         _ = tokens.to_bytes()  # noqa: F841
-        mock_warning.assert_called_once()
 
 
 def test_doc_api_set_ents(en_tokenizer):
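pytest.warns(UserWarning) fails the test when no matching warning is raised, which replaces mock_warning.assert_called_once(); the old assert_not_called() check on the plain to_bytes() call has no direct pytest.warns counterpart and was dropped. If that negative check were still wanted, one way to write it (a sketch, not part of this commit; tokens is the Doc built by the test) would be:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record everything, bypassing filters
        _ = tokens.to_bytes()  # noqa: F841
    assert not caught  # no warning without user hooks set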

View File

@@ -1,6 +1,4 @@
 import pytest
-import logging
-import mock
 import pickle
 from spacy import util, registry
 from spacy.lang.en import English
@@ -59,10 +57,10 @@ def test_lemmatizer_config(nlp):
     # warning if no POS assigned
     doc = nlp.make_doc("coping")
-    logger = logging.getLogger("spacy")
-    with mock.patch.object(logger, "warning") as mock_warning:
+    with pytest.warns(UserWarning):
         doc = lemmatizer(doc)
-        mock_warning.assert_called_once()
+    # warns once by default
+    doc = lemmatizer(doc)
 
     # works with POS
     doc = nlp.make_doc("coping")

View File

@@ -1,6 +1,4 @@
 import pytest
-import mock
-import logging
 from spacy.language import Language
 from spacy.lang.en import English
 from spacy.lang.de import German
@@ -437,10 +435,8 @@ def test_pipe_factories_from_source_language_subclass():
     nlp = English()
     nlp.vocab.vectors.resize((1, 4))
     nlp.vocab.vectors.add("cat", vector=[1, 2, 3, 4])
-    logger = logging.getLogger("spacy")
-    with mock.patch.object(logger, "warning") as mock_warning:
+    with pytest.warns(UserWarning):
         nlp.add_pipe("tagger", source=source_nlp)
-        mock_warning.assert_called()
 
 
 def test_pipe_factories_from_source_custom():

View File

@@ -1318,7 +1318,7 @@ cdef class Doc:
         if "user_data_values" not in exclude:
             serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values)
         if "user_hooks" not in exclude and any((self.user_hooks, self.user_token_hooks, self.user_span_hooks)):
-            util.logger.warning(Warnings.W109)
+            warnings.warn(Warnings.W109)
         return util.to_dict(serializers, exclude)
 
     def from_dict(self, msg, *, exclude=tuple()):
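Warnings.W109 is a plain string, and warnings.warn called with a string defaults the category to UserWarning, which is the category every updated pytest.warns assertion above matches on. A quick self-contained check of that default (the message text is illustrative):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warnings.warn("[W109] user hooks are not serialized")
    assert caught[0].category is UserWarning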