* Don't add duplicate patterns (fix #8216)

* Refactor EntityRuler init

  This simplifies the EntityRuler init code, as prep for allowing the EntityRuler to reset itself.

* Make EntityRuler.clear reset matchers

  Includes a new test for this.

* Tidy PhraseMatcher instantiation

  Since the attr can safely be None now, the guard `if` is no longer required here. Also renamed the `_validate` attr. Maybe it's not needed?

* Fix NER test

* Add test to make sure patterns aren't increasing

* Move test to regression tests
parent d54631f68b
commit d959603d51
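Before the diff, a minimal sketch of the two behaviors this commit targets: repeated add_patterns calls must not duplicate entries in the underlying Matcher, and the new EntityRuler.clear() must reset the matchers as well as the pattern dicts. This is not part of the commit; the pipeline setup and the use of the internal Matcher._patterns attribute (the same check the regression test below uses) are my assumptions, assuming spaCy v3.x with this change applied.

from spacy.lang.en import English

nlp = English()
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns([{"label": "ORG", "pattern": [{"LOWER": "apple"}]}])

# Count patterns held by the internal Matcher (same check as the regression test).
before = sum(len(p) for p in ruler.matcher._patterns.values())
ruler.add_patterns([])  # must not re-add the patterns that are already there (see #8216)
after = sum(len(p) for p in ruler.matcher._patterns.values())
assert after == before

ruler.clear()  # after this commit, also rebuilds the Matcher/PhraseMatcher
assert len(ruler.labels) == 0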
@@ -50,6 +50,8 @@ cdef class PhraseMatcher:
         if isinstance(attr, (int, long)):
             self.attr = attr
         else:
+            if attr is None:
+                attr = "ORTH"
             attr = attr.upper()
             if attr == "TEXT":
                 attr = "ORTH"
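A hedged usage sketch (mine, not from the commit) of what the two added lines allow: passing attr=None to PhraseMatcher is now normalized to the default "ORTH" instead of failing when the attribute is None.

from spacy.lang.en import English
from spacy.matcher import PhraseMatcher

nlp = English()
# attr=None is now treated the same as the default "ORTH"
matcher = PhraseMatcher(nlp.vocab, attr=None)
matcher.add("GREETING", [nlp.make_doc("hello world")])
doc = nlp.make_doc("she said hello world")
assert len(matcher(doc)) == 1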
@@ -101,17 +101,12 @@ class EntityRuler(Pipe):
         self.overwrite = overwrite_ents
         self.token_patterns = defaultdict(list)
         self.phrase_patterns = defaultdict(list)
+        self._validate = validate
         self.matcher = Matcher(nlp.vocab, validate=validate)
-        if phrase_matcher_attr is not None:
-            if phrase_matcher_attr.upper() == "TEXT":
-                phrase_matcher_attr = "ORTH"
-            self.phrase_matcher_attr = phrase_matcher_attr
-            self.phrase_matcher = PhraseMatcher(
-                nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
-            )
-        else:
-            self.phrase_matcher_attr = None
-            self.phrase_matcher = PhraseMatcher(nlp.vocab, validate=validate)
+        self.phrase_matcher_attr = phrase_matcher_attr
+        self.phrase_matcher = PhraseMatcher(
+            nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
+        )
         self.ent_id_sep = ent_id_sep
         self._ent_ids = defaultdict(dict)
         if patterns is not None:
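For context, a hedged sketch (not part of the diff) of the init path this simplification keeps working: phrase_matcher_attr is stored as given, including when set through the pipe config, and handed straight to PhraseMatcher. The pipeline and pattern here are illustrative assumptions.

from spacy.lang.en import English

nlp = English()
ruler = nlp.add_pipe("entity_ruler", config={"phrase_matcher_attr": "LOWER"})
ruler.add_patterns([{"label": "GREETING", "pattern": "hello world"}])

# Phrase patterns are matched on the LOWER attribute, so casing does not matter.
doc = nlp("HELLO WORLD")
assert [ent.label_ for ent in doc.ents] == ["GREETING"]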
@@ -315,20 +310,22 @@ class EntityRuler(Pipe):
                 pattern = entry["pattern"]
                 if isinstance(pattern, Doc):
                     self.phrase_patterns[label].append(pattern)
+                    self.phrase_matcher.add(label, [pattern])
                 elif isinstance(pattern, list):
                     self.token_patterns[label].append(pattern)
+                    self.matcher.add(label, [pattern])
                 else:
                     raise ValueError(Errors.E097.format(pattern=pattern))
-            for label, patterns in self.token_patterns.items():
-                self.matcher.add(label, patterns)
-            for label, patterns in self.phrase_patterns.items():
-                self.phrase_matcher.add(label, patterns)
+
+    def clear(self) -> None:
+        """Reset all patterns."""
+        self.token_patterns = defaultdict(list)
+        self.phrase_patterns = defaultdict(list)
+        self._ent_ids = defaultdict(dict)
+        self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
+        self.phrase_matcher = PhraseMatcher(
+            self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
+        )

     def _split_label(self, label: str) -> Tuple[str, str]:
         """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
@@ -374,10 +371,9 @@ class EntityRuler(Pipe):
             self.add_patterns(cfg.get("patterns", cfg))
             self.overwrite = cfg.get("overwrite", False)
             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
-            if self.phrase_matcher_attr is not None:
-                self.phrase_matcher = PhraseMatcher(
-                    self.nlp.vocab, attr=self.phrase_matcher_attr
-                )
+            self.phrase_matcher = PhraseMatcher(
+                self.nlp.vocab, attr=self.phrase_matcher_attr
+            )
             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
         else:
             self.add_patterns(cfg)
@@ -428,10 +424,9 @@ class EntityRuler(Pipe):
             self.phrase_matcher_attr = cfg.get("phrase_matcher_attr")
             self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)

-            if self.phrase_matcher_attr is not None:
-                self.phrase_matcher = PhraseMatcher(
-                    self.nlp.vocab, attr=self.phrase_matcher_attr
-                )
+            self.phrase_matcher = PhraseMatcher(
+                self.nlp.vocab, attr=self.phrase_matcher_attr
+            )
             from_disk(path, deserializers_patterns, {})
         return self
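A hedged round-trip sketch (mine, not the commit's) of what these two deserialization hunks rely on: the PhraseMatcher is rebuilt unconditionally from whatever phrase_matcher_attr was stored, including None. The config value and pattern are illustrative assumptions.

from spacy.lang.en import English

nlp = English()
ruler = nlp.add_pipe("entity_ruler", config={"phrase_matcher_attr": "LOWER"})
ruler.add_patterns([{"label": "ORG", "pattern": "apple"}])
data = ruler.to_bytes()

# Restoring into a fresh ruler keeps the stored attr and the patterns.
nlp2 = English()
ruler2 = nlp2.add_pipe("entity_ruler")
ruler2.from_bytes(data)
assert ruler2.phrase_matcher_attr == "LOWER"
assert len(ruler2.patterns) == len(ruler.patterns)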
@@ -252,12 +252,12 @@ def test_ruler_before_ner():
     # 1 : Entity Ruler - should set "this" to B and everything else to empty
     patterns = [{"label": "THING", "pattern": "This"}]
     ruler = nlp.add_pipe("entity_ruler")
-    ruler.add_patterns(patterns)

     # 2: untrained NER - should set everything else to O
     untrained_ner = nlp.add_pipe("ner")
     untrained_ner.add_label("MY_LABEL")
     nlp.initialize()
+    ruler.add_patterns(patterns)
     doc = nlp("This is Antti Korhonen speaking in Finland")
     expected_iobs = ["B", "O", "O", "O", "O", "O", "O"]
     expected_types = ["THING", "", "", "", "", "", ""]
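My reading of why this test moves ruler.add_patterns after nlp.initialize() (not stated in the diff itself): with this commit, initialization resets the entity_ruler, so patterns added beforehand would be cleared. A minimal sketch of the required ordering, with illustrative labels and patterns:

from spacy.lang.en import English

nlp = English()
ruler = nlp.add_pipe("entity_ruler")
ner = nlp.add_pipe("ner")
ner.add_label("MY_LABEL")
nlp.initialize()  # re-initializes the entity_ruler, which clears its patterns

# Add (or re-add) patterns only after initialization so they survive.
ruler.add_patterns([{"label": "THING", "pattern": "This"}])
assert "THING" in ruler.labels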
@@ -78,6 +78,19 @@ def test_entity_ruler_init_clear(nlp, patterns):
     assert len(ruler.labels) == 0


+def test_entity_ruler_clear(nlp, patterns):
+    """Test that initialization clears patterns."""
+    ruler = nlp.add_pipe("entity_ruler")
+    ruler.add_patterns(patterns)
+    assert len(ruler.labels) == 4
+    doc = nlp("hello world")
+    assert len(doc.ents) == 1
+    ruler.clear()
+    assert len(ruler.labels) == 0
+    doc = nlp("hello world")
+    assert len(doc.ents) == 0
+
+
 def test_entity_ruler_existing(nlp, patterns):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)
spacy/tests/regression/test_issue8216.py (new file, 34 additions)
@@ -0,0 +1,34 @@
+import pytest
+
+from spacy import registry
+from spacy.language import Language
+from spacy.pipeline import EntityRuler
+
+
+@pytest.fixture
+def nlp():
+    return Language()
+
+
+@pytest.fixture
+@registry.misc("entity_ruler_patterns")
+def patterns():
+    return [
+        {"label": "HELLO", "pattern": "hello world"},
+        {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]},
+        {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
+        {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
+        {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
+        {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
+    ]
+
+
+def test_entity_ruler_fix8216(nlp, patterns):
+    """Test that patterns don't get added excessively."""
+    ruler = nlp.add_pipe("entity_ruler", config={"validate": True})
+    ruler.add_patterns(patterns)
+    pattern_count = sum(len(mm) for mm in ruler.matcher._patterns.values())
+    assert pattern_count > 0
+    ruler.add_patterns([])
+    after_count = sum(len(mm) for mm in ruler.matcher._patterns.values())
+    assert after_count == pattern_count