diff --git a/spacy/errors.py b/spacy/errors.py index c5e364013..ad7a0280f 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -889,6 +889,8 @@ class Errors(metaclass=ErrorsWithCodes): "Non-UD tags should use the `tag` property.") E1022 = ("Words must be of type str or int, but input is of type '{wtype}'") E1023 = ("Couldn't read EntityRuler from the {path}. This file doesn't exist.") + E1024 = ("A pattern with ID \"{ent_id}\" is not present in EntityRuler patterns.") + # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/matcher/phrasematcher.pyi b/spacy/matcher/phrasematcher.pyi index d73633ec0..741bf7bb6 100644 --- a/spacy/matcher/phrasematcher.pyi +++ b/spacy/matcher/phrasematcher.pyi @@ -8,12 +8,9 @@ class PhraseMatcher: def __init__( self, vocab: Vocab, attr: Optional[Union[int, str]], validate: bool = ... ) -> None: ... - def __call__( - self, - doclike: Union[Doc, Span], - *, - as_spans: bool = ..., - ) -> Union[List[Tuple[int, int, int]], List[Span]]: ... + def __reduce__(self) -> Any: ... + def __len__(self) -> int: ... + def __contains__(self, key: str) -> bool: ... def add( self, key: str, @@ -23,3 +20,10 @@ class PhraseMatcher: Callable[[Matcher, Doc, int, List[Tuple[Any, ...]]], Any] ] = ..., ) -> None: ... + def remove(self, key: str) -> None: ... + def __call__( + self, + doclike: Union[Doc, Span], + *, + as_spans: bool = ..., + ) -> Union[List[Tuple[int, int, int]], List[Span]]: ... diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 78d7a0be2..614d71f41 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -348,6 +348,46 @@ class EntityRuler(Pipe): self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate ) + def remove(self, ent_id: str) -> None: + """Remove a pattern by its ent_id if a pattern with this ent_id was added before + + ent_id (str): id of the pattern to be removed + RETURNS: None + DOCS: https://spacy.io/api/entityruler#remove + """ + label_id_pairs = [ + (label, eid) for (label, eid) in self._ent_ids.values() if eid == ent_id + ] + if not label_id_pairs: + raise ValueError(Errors.E1024.format(ent_id=ent_id)) + created_labels = [ + self._create_label(label, eid) for (label, eid) in label_id_pairs + ] + # remove the patterns from self.phrase_patterns + self.phrase_patterns = defaultdict( + list, + { + label: val + for (label, val) in self.phrase_patterns.items() + if label not in created_labels + }, + ) + # remove the patterns from self.token_pattern + self.token_patterns = defaultdict( + list, + { + label: val + for (label, val) in self.token_patterns.items() + if label not in created_labels + }, + ) + # remove the patterns from self.token_pattern + for label in created_labels: + if label in self.phrase_matcher: + self.phrase_matcher.remove(label) + else: + self.matcher.remove(label) + def _require_patterns(self) -> None: """Raise a warning if this component has no patterns defined.""" if len(self) == 0: diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py index 0cecafff3..f2031d0a9 100644 --- a/spacy/tests/pipeline/test_entity_ruler.py +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -373,3 +373,185 @@ def test_entity_ruler_serialize_dir(nlp, patterns): ruler.from_disk(d / "test_ruler") # read from an existing directory with pytest.raises(ValueError): ruler.from_disk(d / "non_existing_dir") # read from a bad directory + + +def test_entity_ruler_remove_basic(nlp): + ruler = EntityRuler(nlp) + patterns = [ + {"label": "PERSON", "pattern": "Duygu", "id": "duygu"}, + {"label": "ORG", "pattern": "ACME", "id": "acme"}, + {"label": "ORG", "pattern": "ACM"}, + ] + ruler.add_patterns(patterns) + doc = ruler(nlp.make_doc("Duygu went to school")) + assert len(ruler.patterns) == 3 + assert len(doc.ents) == 1 + assert doc.ents[0].label_ == "PERSON" + assert doc.ents[0].text == "Duygu" + assert "PERSON||duygu" in ruler.phrase_matcher + ruler.remove("duygu") + doc = ruler(nlp.make_doc("Duygu went to school")) + assert len(doc.ents) == 0 + assert "PERSON||duygu" not in ruler.phrase_matcher + assert len(ruler.patterns) == 2 + + +def test_entity_ruler_remove_same_id_multiple_patterns(nlp): + ruler = EntityRuler(nlp) + patterns = [ + {"label": "PERSON", "pattern": "Duygu", "id": "duygu"}, + {"label": "ORG", "pattern": "DuyguCorp", "id": "duygu"}, + {"label": "ORG", "pattern": "ACME", "id": "acme"}, + ] + ruler.add_patterns(patterns) + doc = ruler(nlp.make_doc("Duygu founded DuyguCorp and ACME.")) + assert len(ruler.patterns) == 3 + assert "PERSON||duygu" in ruler.phrase_matcher + assert "ORG||duygu" in ruler.phrase_matcher + assert len(doc.ents) == 3 + ruler.remove("duygu") + doc = ruler(nlp.make_doc("Duygu founded DuyguCorp and ACME.")) + assert len(ruler.patterns) == 1 + assert "PERSON||duygu" not in ruler.phrase_matcher + assert "ORG||duygu" not in ruler.phrase_matcher + assert len(doc.ents) == 1 + + +def test_entity_ruler_remove_nonexisting_pattern(nlp): + ruler = EntityRuler(nlp) + patterns = [ + {"label": "PERSON", "pattern": "Duygu", "id": "duygu"}, + {"label": "ORG", "pattern": "ACME", "id": "acme"}, + {"label": "ORG", "pattern": "ACM"}, + ] + ruler.add_patterns(patterns) + assert len(ruler.patterns) == 3 + with pytest.raises(ValueError): + ruler.remove("nepattern") + assert len(ruler.patterns) == 3 + + +def test_entity_ruler_remove_several_patterns(nlp): + ruler = EntityRuler(nlp) + patterns = [ + {"label": "PERSON", "pattern": "Duygu", "id": "duygu"}, + {"label": "ORG", "pattern": "ACME", "id": "acme"}, + {"label": "ORG", "pattern": "ACM"}, + ] + ruler.add_patterns(patterns) + doc = ruler(nlp.make_doc("Duygu founded her company ACME.")) + assert len(ruler.patterns) == 3 + assert len(doc.ents) == 2 + assert doc.ents[0].label_ == "PERSON" + assert doc.ents[0].text == "Duygu" + assert doc.ents[1].label_ == "ORG" + assert doc.ents[1].text == "ACME" + ruler.remove("duygu") + doc = ruler(nlp.make_doc("Duygu founded her company ACME")) + assert len(ruler.patterns) == 2 + assert len(doc.ents) == 1 + assert doc.ents[0].label_ == "ORG" + assert doc.ents[0].text == "ACME" + ruler.remove("acme") + doc = ruler(nlp.make_doc("Duygu founded her company ACME")) + assert len(ruler.patterns) == 1 + assert len(doc.ents) == 0 + + +def test_entity_ruler_remove_patterns_in_a_row(nlp): + ruler = EntityRuler(nlp) + patterns = [ + {"label": "PERSON", "pattern": "Duygu", "id": "duygu"}, + {"label": "ORG", "pattern": "ACME", "id": "acme"}, + {"label": "DATE", "pattern": "her birthday", "id": "bday"}, + {"label": "ORG", "pattern": "ACM"}, + ] + ruler.add_patterns(patterns) + doc = ruler(nlp.make_doc("Duygu founded her company ACME on her birthday")) + assert len(doc.ents) == 3 + assert doc.ents[0].label_ == "PERSON" + assert doc.ents[0].text == "Duygu" + assert doc.ents[1].label_ == "ORG" + assert doc.ents[1].text == "ACME" + assert doc.ents[2].label_ == "DATE" + assert doc.ents[2].text == "her birthday" + ruler.remove("duygu") + ruler.remove("acme") + ruler.remove("bday") + doc = ruler(nlp.make_doc("Duygu went to school")) + assert len(doc.ents) == 0 + + +def test_entity_ruler_remove_all_patterns(nlp): + ruler = EntityRuler(nlp) + patterns = [ + {"label": "PERSON", "pattern": "Duygu", "id": "duygu"}, + {"label": "ORG", "pattern": "ACME", "id": "acme"}, + {"label": "DATE", "pattern": "her birthday", "id": "bday"}, + ] + ruler.add_patterns(patterns) + assert len(ruler.patterns) == 3 + ruler.remove("duygu") + assert len(ruler.patterns) == 2 + ruler.remove("acme") + assert len(ruler.patterns) == 1 + ruler.remove("bday") + assert len(ruler.patterns) == 0 + with pytest.warns(UserWarning): + doc = ruler(nlp.make_doc("Duygu founded her company ACME on her birthday")) + assert len(doc.ents) == 0 + + +def test_entity_ruler_remove_and_add(nlp): + ruler = EntityRuler(nlp) + patterns = [{"label": "DATE", "pattern": "last time"}] + ruler.add_patterns(patterns) + doc = ruler( + nlp.make_doc("I saw him last time we met, this time he brought some flowers") + ) + assert len(ruler.patterns) == 1 + assert len(doc.ents) == 1 + assert doc.ents[0].label_ == "DATE" + assert doc.ents[0].text == "last time" + patterns1 = [{"label": "DATE", "pattern": "this time", "id": "ttime"}] + ruler.add_patterns(patterns1) + doc = ruler( + nlp.make_doc("I saw him last time we met, this time he brought some flowers") + ) + assert len(ruler.patterns) == 2 + assert len(doc.ents) == 2 + assert doc.ents[0].label_ == "DATE" + assert doc.ents[0].text == "last time" + assert doc.ents[1].label_ == "DATE" + assert doc.ents[1].text == "this time" + ruler.remove("ttime") + doc = ruler( + nlp.make_doc("I saw him last time we met, this time he brought some flowers") + ) + assert len(ruler.patterns) == 1 + assert len(doc.ents) == 1 + assert doc.ents[0].label_ == "DATE" + assert doc.ents[0].text == "last time" + ruler.add_patterns(patterns1) + doc = ruler( + nlp.make_doc("I saw him last time we met, this time he brought some flowers") + ) + assert len(ruler.patterns) == 2 + assert len(doc.ents) == 2 + patterns2 = [{"label": "DATE", "pattern": "another time", "id": "ttime"}] + ruler.add_patterns(patterns2) + doc = ruler( + nlp.make_doc( + "I saw him last time we met, this time he brought some flowers, another time some chocolate." + ) + ) + assert len(ruler.patterns) == 3 + assert len(doc.ents) == 3 + ruler.remove("ttime") + doc = ruler( + nlp.make_doc( + "I saw him last time we met, this time he brought some flowers, another time some chocolate." + ) + ) + assert len(ruler.patterns) == 1 + assert len(doc.ents) == 1 diff --git a/website/docs/api/entityruler.md b/website/docs/api/entityruler.md index fb33642f8..6d8f835bf 100644 --- a/website/docs/api/entityruler.md +++ b/website/docs/api/entityruler.md @@ -210,6 +210,24 @@ of dicts) or a phrase pattern (string). For more details, see the usage guide on | ---------- | ---------------------------------------------------------------- | | `patterns` | The patterns to add. ~~List[Dict[str, Union[str, List[dict]]]]~~ | + +## EntityRuler.remove {#remove tag="method" new="3.2.1"} + +Remove a pattern by its ID from the entity ruler. A `ValueError` is raised if the ID does not exist. + +> #### Example +> +> ```python +> patterns = [{"label": "ORG", "pattern": "Apple", "id": "apple"}] +> ruler = nlp.add_pipe("entity_ruler") +> ruler.add_patterns(patterns) +> ruler.remove("apple") +> ``` + +| Name | Description | +| ---------- | ---------------------------------------------------------------- | +| `id` | The ID of the pattern rule. ~~str~~ | + ## EntityRuler.to_disk {#to_disk tag="method"} Save the entity ruler patterns to a directory. The patterns will be saved as