Don't add duplicate patterns all the time in EntityRuler (fix #8216) (#8246)

* Don't add duplicate patterns (fix #8216)

* Refactor EntityRuler init

This simplifies the EntityRuler init code. This is helpful as prep for
allowing the EntityRuler to reset itself.

* Make EntityRuler.clear reset matchers

Includes a new test for this.

* Tidy PhraseMatcher instantiation

Since `attr` can now safely be None, the `if` guard is no longer
required here.

Also renamed the `_validate` attr. Maybe it's not needed?

* Fix NER test

* Add test to make sure patterns aren't increasing

* Move test to regression tests
This commit is contained in:
Paul O'Leary McCann 2021-06-03 16:05:26 +09:00 committed by GitHub
parent d54631f68b
commit d959603d51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 23 deletions

View File

@ -50,6 +50,8 @@ cdef class PhraseMatcher:
if isinstance(attr, (int, long)):
self.attr = attr
else:
if attr is None:
attr = "ORTH"
attr = attr.upper()
if attr == "TEXT":
attr = "ORTH"

View File

@ -101,17 +101,12 @@ class EntityRuler(Pipe):
self.overwrite = overwrite_ents
self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list)
self._validate = validate
self.matcher = Matcher(nlp.vocab, validate=validate)
if phrase_matcher_attr is not None:
if phrase_matcher_attr.upper() == "TEXT":
phrase_matcher_attr = "ORTH"
self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher(
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
)
else:
self.phrase_matcher_attr = None
self.phrase_matcher = PhraseMatcher(nlp.vocab, validate=validate)
self.phrase_matcher_attr = phrase_matcher_attr
self.phrase_matcher = PhraseMatcher(
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
)
self.ent_id_sep = ent_id_sep
self._ent_ids = defaultdict(dict)
if patterns is not None:
@ -315,20 +310,22 @@ class EntityRuler(Pipe):
pattern = entry["pattern"]
if isinstance(pattern, Doc):
self.phrase_patterns[label].append(pattern)
self.phrase_matcher.add(label, [pattern])
elif isinstance(pattern, list):
self.token_patterns[label].append(pattern)
self.matcher.add(label, [pattern])
else:
raise ValueError(Errors.E097.format(pattern=pattern))
for label, patterns in self.token_patterns.items():
self.matcher.add(label, patterns)
for label, patterns in self.phrase_patterns.items():
self.phrase_matcher.add(label, patterns)
def clear(self) -> None:
    """Reset all patterns.

    Replaces the stored token patterns, phrase patterns and entity-id
    mapping with empty defaultdicts, then swaps in newly constructed
    Matcher and PhraseMatcher instances built from the pipeline vocab,
    the configured phrase matcher attribute and the validate flag.
    """
    # Empty the bookkeeping containers first.
    self.token_patterns = defaultdict(list)
    self.phrase_patterns = defaultdict(list)
    self._ent_ids = defaultdict(dict)

    # Rebuild both matchers with the same settings the ruler holds.
    vocab = self.nlp.vocab
    validate = self._validate
    self.matcher = Matcher(vocab, validate=validate)
    self.phrase_matcher = PhraseMatcher(
        vocab,
        attr=self.phrase_matcher_attr,
        validate=validate,
    )
def _split_label(self, label: str) -> Tuple[str, str]:
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
@ -374,10 +371,9 @@ class EntityRuler(Pipe):
self.add_patterns(cfg.get("patterns", cfg))
self.overwrite = cfg.get("overwrite", False)
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr
)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr
)
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
else:
self.add_patterns(cfg)
@ -428,10 +424,9 @@ class EntityRuler(Pipe):
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
if self.phrase_matcher_attr is not None:
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr
)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr
)
from_disk(path, deserializers_patterns, {})
return self

View File

@ -252,12 +252,12 @@ def test_ruler_before_ner():
# 1 : Entity Ruler - should set "this" to B and everything else to empty
patterns = [{"label": "THING", "pattern": "This"}]
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns)
# 2: untrained NER - should set everything else to O
untrained_ner = nlp.add_pipe("ner")
untrained_ner.add_label("MY_LABEL")
nlp.initialize()
ruler.add_patterns(patterns)
doc = nlp("This is Antti Korhonen speaking in Finland")
expected_iobs = ["B", "O", "O", "O", "O", "O", "O"]
expected_types = ["THING", "", "", "", "", "", ""]

View File

@ -78,6 +78,19 @@ def test_entity_ruler_init_clear(nlp, patterns):
assert len(ruler.labels) == 0
def test_entity_ruler_clear(nlp, patterns):
    """Test that EntityRuler.clear() removes patterns and stops matching.

    After adding the fixture patterns the ruler reports four labels and
    finds one entity in the sample text; after clear() it reports no labels
    and the same text yields no entities.
    """
    text = "hello world"
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    assert len(ruler.labels) == 4
    doc = nlp(text)
    assert len(doc.ents) == 1
    ruler.clear()
    assert len(ruler.labels) == 0
    # Re-run the pipeline: with no patterns left there should be no matches.
    doc = nlp(text)
    assert len(doc.ents) == 0
def test_entity_ruler_existing(nlp, patterns):
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns)

View File

@ -0,0 +1,34 @@
import pytest
from spacy import registry
from spacy.language import Language
from spacy.pipeline import EntityRuler
@pytest.fixture
def nlp():
    """Provide a newly constructed Language object to each test."""
    pipeline = Language()
    return pipeline
@pytest.fixture
@registry.misc("entity_ruler_patterns")
def patterns():
    """Return the sample EntityRuler patterns used by the tests.

    The list mixes string patterns with token-attribute list patterns, and
    two of the entries carry an "id" key.
    """
    hello_phrase = {"label": "HELLO", "pattern": "hello world"}
    bye_tokens = {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]}
    hello_token = {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]}
    complex_tokens = {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}
    apple = {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"}
    microsoft = {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"}
    return [
        hello_phrase,
        bye_tokens,
        hello_token,
        complex_tokens,
        apple,
        microsoft,
    ]
def test_entity_ruler_fix8216(nlp, patterns):
    """Test that patterns don't get added excessively.

    Adds the fixture patterns, records how many patterns the ruler's token
    matcher holds, then calls add_patterns with an empty list and checks
    that the count is unchanged.
    """

    def count_matcher_patterns(entity_ruler):
        # Look up .matcher on every call so the current matcher is counted.
        stored = entity_ruler.matcher._patterns
        return sum(len(mm) for mm in stored.values())

    ruler = nlp.add_pipe("entity_ruler", config={"validate": True})
    ruler.add_patterns(patterns)
    pattern_count = count_matcher_patterns(ruler)
    assert pattern_count > 0

    # Adding nothing must not re-add the patterns that are already stored.
    ruler.add_patterns([])
    after_count = count_matcher_patterns(ruler)
    assert after_count == pattern_count