Show warning if entity_ruler runs without patterns (#7807)

* Show warning if entity_ruler runs without patterns

* Show warning if matcher runs without patterns

* fix wording

* unit test for warning once (WIP)

* warn W036 only once

* cleanup

* create filter_warning helper
This commit is contained in:
Sofie Van Landeghem 2021-05-31 10:20:27 +02:00 committed by GitHub
parent d1a221a374
commit ff91e6dac7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 5 deletions

View File

@ -1,10 +1,10 @@
from typing import Union, Iterable, Dict, Any from typing import Union, Iterable, Dict, Any
from pathlib import Path from pathlib import Path
import warnings
import sys import sys
warnings.filterwarnings("ignore", message="numpy.dtype size changed") # noqa # set library-specific custom warning handling before doing anything else
warnings.filterwarnings("ignore", message="numpy.ufunc size changed") # noqa from .errors import setup_default_warnings
setup_default_warnings()
# These are imported as part of the API # These are imported as part of the API
from thinc.api import prefer_gpu, require_gpu, require_cpu # noqa: F401 from thinc.api import prefer_gpu, require_gpu, require_cpu # noqa: F401

View File

@ -1,3 +1,6 @@
import warnings
def add_codes(err_cls): def add_codes(err_cls):
"""Add error codes to string messages via class attribute names.""" """Add error codes to string messages via class attribute names."""
@ -12,6 +15,30 @@ def add_codes(err_cls):
return ErrorsWithCodes() return ErrorsWithCodes()
def setup_default_warnings():
# ignore certain numpy warnings
filter_warning("ignore", error_msg="numpy.dtype size changed") # noqa
filter_warning("ignore", error_msg="numpy.ufunc size changed") # noqa
# warn about entity_ruler & matcher having no patterns only once
for pipe in ["matcher", "entity_ruler"]:
filter_warning("once", error_msg=Warnings.W036.format(name=pipe))
def filter_warning(action: str, error_msg: str):
"""Customize how spaCy should handle a certain warning.
error_msg (str): e.g. "W006", or a full error message
action (str): "default", "error", "ignore", "always", "module" or "once"
"""
warnings.filterwarnings(action, message=_escape_warning_msg(error_msg))
def _escape_warning_msg(msg):
"""To filter with warnings.filterwarnings, the [] brackets need to be escaped"""
return msg.replace("[", "\\[").replace("]", "\\]")
# fmt: off # fmt: off
@add_codes @add_codes
@ -80,8 +107,9 @@ class Warnings:
"@misc = \"spacy.LookupsDataLoader.v1\"\n" "@misc = \"spacy.LookupsDataLoader.v1\"\n"
"lang = ${{nlp.lang}}\n" "lang = ${{nlp.lang}}\n"
"tables = [\"lexeme_norm\"]\n") "tables = [\"lexeme_norm\"]\n")
W035 = ('Discarding subpattern "{pattern}" due to an unrecognized ' W035 = ("Discarding subpattern '{pattern}' due to an unrecognized "
"attribute or operator.") "attribute or operator.")
W036 = ("The component '{name}' does not have any patterns defined.")
# New warnings added in v3.x # New warnings added in v3.x
W086 = ("Component '{listener}' will be (re)trained, but it needs the component " W086 = ("Component '{listener}' will be (re)trained, but it needs the component "

View File

@ -138,6 +138,11 @@ cdef class Matcher:
self._filter[key] = greedy self._filter[key] = greedy
self._patterns[key].extend(patterns) self._patterns[key].extend(patterns)
def _require_patterns(self) -> None:
"""Raise a warning if this component has no patterns defined."""
if len(self) == 0:
warnings.warn(Warnings.W036.format(name="matcher"))
def remove(self, key): def remove(self, key):
"""Remove a rule from the matcher. A KeyError is raised if the key does """Remove a rule from the matcher. A KeyError is raised if the key does
not exist. not exist.
@ -215,6 +220,7 @@ cdef class Matcher:
If with_alignments is set to True and as_spans is set to False, If with_alignments is set to True and as_spans is set to False,
A list of `(match_id, start, end, alignments)` tuples is returned. A list of `(match_id, start, end, alignments)` tuples is returned.
""" """
self._require_patterns()
if isinstance(doclike, Doc): if isinstance(doclike, Doc):
doc = doclike doc = doclike
length = len(doc) length = len(doc)

View File

@ -1,3 +1,4 @@
import warnings
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable, Sequence from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable, Sequence
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@ -6,7 +7,7 @@ import srsly
from .pipe import Pipe from .pipe import Pipe
from ..training import Example from ..training import Example
from ..language import Language from ..language import Language
from ..errors import Errors from ..errors import Errors, Warnings
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList
from ..tokens import Doc, Span from ..tokens import Doc, Span
from ..matcher import Matcher, PhraseMatcher from ..matcher import Matcher, PhraseMatcher
@ -144,6 +145,7 @@ class EntityRuler(Pipe):
error_handler(self.name, self, [doc], e) error_handler(self.name, self, [doc], e)
def match(self, doc: Doc): def match(self, doc: Doc):
self._require_patterns()
matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc)) matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
matches = set( matches = set(
[(m_id, start, end) for m_id, start, end in matches if start != end] [(m_id, start, end) for m_id, start, end in matches if start != end]
@ -330,6 +332,11 @@ class EntityRuler(Pipe):
self.phrase_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list)
self._ent_ids = defaultdict(dict) self._ent_ids = defaultdict(dict)
def _require_patterns(self) -> None:
"""Raise a warning if this component has no patterns defined."""
if len(self) == 0:
warnings.warn(Warnings.W036.format(name=self.name))
def _split_label(self, label: str) -> Tuple[str, str]: def _split_label(self, label: str) -> Tuple[str, str]:
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep

View File

@ -33,6 +33,15 @@ def test_matcher_from_api_docs(en_vocab):
assert len(patterns[0]) assert len(patterns[0])
def test_matcher_empty_patterns_warns(en_vocab):
matcher = Matcher(en_vocab)
assert len(matcher) == 0
doc = Doc(en_vocab, words=["This", "is", "quite", "something"])
with pytest.warns(UserWarning):
matcher(doc)
assert len(doc.ents) == 0
def test_matcher_from_usage_docs(en_vocab): def test_matcher_from_usage_docs(en_vocab):
text = "Wow 😀 This is really cool! 😂 😂" text = "Wow 😀 This is really cool! 😂 😂"
doc = Doc(en_vocab, words=text.split(" ")) doc = Doc(en_vocab, words=text.split(" "))

View File

@ -46,6 +46,17 @@ def test_entity_ruler_init(nlp, patterns):
assert doc.ents[1].label_ == "BYE" assert doc.ents[1].label_ == "BYE"
def test_entity_ruler_no_patterns_warns(nlp):
ruler = EntityRuler(nlp)
assert len(ruler) == 0
assert len(ruler.labels) == 0
nlp.add_pipe("entity_ruler")
assert nlp.pipe_names == ["entity_ruler"]
with pytest.warns(UserWarning):
doc = nlp("hello world bye bye")
assert len(doc.ents) == 0
def test_entity_ruler_init_patterns(nlp, patterns): def test_entity_ruler_init_patterns(nlp, patterns):
# initialize with patterns # initialize with patterns
ruler = nlp.add_pipe("entity_ruler") ruler = nlp.add_pipe("entity_ruler")