diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index e5ff2202c..d8486b84b 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -50,6 +50,8 @@ cdef class PhraseMatcher: if isinstance(attr, (int, long)): self.attr = attr else: + if attr is None: + attr = "ORTH" attr = attr.upper() if attr == "TEXT": attr = "ORTH" diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 4e61dbca7..03730f772 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -101,17 +101,12 @@ class EntityRuler(Pipe): self.overwrite = overwrite_ents self.token_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list) + self._validate = validate self.matcher = Matcher(nlp.vocab, validate=validate) - if phrase_matcher_attr is not None: - if phrase_matcher_attr.upper() == "TEXT": - phrase_matcher_attr = "ORTH" - self.phrase_matcher_attr = phrase_matcher_attr - self.phrase_matcher = PhraseMatcher( - nlp.vocab, attr=self.phrase_matcher_attr, validate=validate - ) - else: - self.phrase_matcher_attr = None - self.phrase_matcher = PhraseMatcher(nlp.vocab, validate=validate) + self.phrase_matcher_attr = phrase_matcher_attr + self.phrase_matcher = PhraseMatcher( + nlp.vocab, attr=self.phrase_matcher_attr, validate=validate + ) self.ent_id_sep = ent_id_sep self._ent_ids = defaultdict(dict) if patterns is not None: @@ -315,20 +310,22 @@ class EntityRuler(Pipe): pattern = entry["pattern"] if isinstance(pattern, Doc): self.phrase_patterns[label].append(pattern) + self.phrase_matcher.add(label, [pattern]) elif isinstance(pattern, list): self.token_patterns[label].append(pattern) + self.matcher.add(label, [pattern]) else: raise ValueError(Errors.E097.format(pattern=pattern)) - for label, patterns in self.token_patterns.items(): - self.matcher.add(label, patterns) - for label, patterns in self.phrase_patterns.items(): - self.phrase_matcher.add(label, patterns) def clear(self) -> None: """Reset all patterns.""" self.token_patterns = defaultdict(list) self.phrase_patterns = defaultdict(list) self._ent_ids = defaultdict(dict) + self.matcher = Matcher(self.nlp.vocab, validate=self._validate) + self.phrase_matcher = PhraseMatcher( + self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate + ) def _split_label(self, label: str) -> Tuple[str, str]: """Split Entity label into ent_label and ent_id if it contains self.ent_id_sep @@ -374,10 +371,9 @@ class EntityRuler(Pipe): self.add_patterns(cfg.get("patterns", cfg)) self.overwrite = cfg.get("overwrite", False) self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None) - if self.phrase_matcher_attr is not None: - self.phrase_matcher = PhraseMatcher( - self.nlp.vocab, attr=self.phrase_matcher_attr - ) + self.phrase_matcher = PhraseMatcher( + self.nlp.vocab, attr=self.phrase_matcher_attr + ) self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) else: self.add_patterns(cfg) @@ -428,10 +424,9 @@ class EntityRuler(Pipe): self.phrase_matcher_attr = cfg.get("phrase_matcher_attr") self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP) - if self.phrase_matcher_attr is not None: - self.phrase_matcher = PhraseMatcher( - self.nlp.vocab, attr=self.phrase_matcher_attr - ) + self.phrase_matcher = PhraseMatcher( + self.nlp.vocab, attr=self.phrase_matcher_attr + ) from_disk(path, deserializers_patterns, {}) return self diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index dffdff1ec..1b9d0b255 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -252,12 +252,12 @@ def test_ruler_before_ner(): # 1 : Entity Ruler - should set "this" to B and everything else to empty patterns = [{"label": "THING", "pattern": "This"}] ruler = nlp.add_pipe("entity_ruler") - ruler.add_patterns(patterns) # 2: untrained NER - should set everything else to O untrained_ner = nlp.add_pipe("ner") untrained_ner.add_label("MY_LABEL") nlp.initialize() + ruler.add_patterns(patterns) doc = nlp("This is Antti Korhonen speaking in Finland") expected_iobs = ["B", "O", "O", "O", "O", "O", "O"] expected_types = ["THING", "", "", "", "", "", ""] diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py index 2f6da79d6..79ad44abd 100644 --- a/spacy/tests/pipeline/test_entity_ruler.py +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -78,6 +78,19 @@ def test_entity_ruler_init_clear(nlp, patterns): assert len(ruler.labels) == 0 +def test_entity_ruler_clear(nlp, patterns): + """Test that initialization clears patterns.""" + ruler = nlp.add_pipe("entity_ruler") + ruler.add_patterns(patterns) + assert len(ruler.labels) == 4 + doc = nlp("hello world") + assert len(doc.ents) == 1 + ruler.clear() + assert len(ruler.labels) == 0 + doc = nlp("hello world") + assert len(doc.ents) == 0 + + def test_entity_ruler_existing(nlp, patterns): ruler = nlp.add_pipe("entity_ruler") ruler.add_patterns(patterns) diff --git a/spacy/tests/regression/test_issue8216.py b/spacy/tests/regression/test_issue8216.py new file mode 100644 index 000000000..528d4b6f9 --- /dev/null +++ b/spacy/tests/regression/test_issue8216.py @@ -0,0 +1,34 @@ +import pytest + +from spacy import registry +from spacy.language import Language +from spacy.pipeline import EntityRuler + + +@pytest.fixture +def nlp(): + return Language() + + +@pytest.fixture +@registry.misc("entity_ruler_patterns") +def patterns(): + return [ + {"label": "HELLO", "pattern": "hello world"}, + {"label": "BYE", "pattern": [{"LOWER": "bye"}, {"LOWER": "bye"}]}, + {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]}, + {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}, + {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"}, + {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"}, + ] + + +def test_entity_ruler_fix8216(nlp, patterns): + """Test that patterns don't get added excessively.""" + ruler = nlp.add_pipe("entity_ruler", config={"validate": True}) + ruler.add_patterns(patterns) + pattern_count = sum(len(mm) for mm in ruler.matcher._patterns.values()) + assert pattern_count > 0 + ruler.add_patterns([]) + after_count = sum(len(mm) for mm in ruler.matcher._patterns.values()) + assert after_count == pattern_count