From b9afcd56e3532125ec15f7d1f0825608c04835e3 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Wed, 15 Jan 2020 17:01:31 -0800 Subject: [PATCH] Fix ent_ids and labels properties when id attribute used in patterns (#4900) * Fix ent_ids and labels properties when id attribute used in patterns * use set for labels * sort ent_ids for comparison in entity_ruler tests * fixing entity_ruler ent_ids test * add to set --- spacy/pipeline/entityruler.py | 19 +++++++++++++++---- spacy/tests/pipeline/test_entity_ruler.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 2db312d64..1c8429049 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -129,20 +129,31 @@ class EntityRuler(object): DOCS: https://spacy.io/api/entityruler#labels """ - all_labels = set(self.token_patterns.keys()) - all_labels.update(self.phrase_patterns.keys()) + keys = set(self.token_patterns.keys()) + keys.update(self.phrase_patterns.keys()) + all_labels = set() + + for l in keys: + if self.ent_id_sep in l: + label, _ = self._split_label(l) + all_labels.add(label) + else: + all_labels.add(l) return tuple(all_labels) @property def ent_ids(self): - """All entity ids present in the match patterns `id` properties. + """All entity ids present in the match patterns `id` properties RETURNS (set): The string entity ids. 
DOCS: https://spacy.io/api/entityruler#ent_ids """ + keys = set(self.token_patterns.keys()) + keys.update(self.phrase_patterns.keys()) all_ent_ids = set() - for l in self.labels: + + for l in keys: if self.ent_id_sep in l: _, ent_id = self._split_label(l) all_ent_ids.add(ent_id) diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py index 660ad3b28..3b46baa9b 100644 --- a/spacy/tests/pipeline/test_entity_ruler.py +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -21,6 +21,7 @@ def patterns(): {"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]}, {"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]}, {"label": "TECH_ORG", "pattern": "Apple", "id": "a1"}, + {"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"}, ] @@ -147,3 +148,14 @@ def test_entity_ruler_validate(nlp): # invalid pattern raises error with validate with pytest.raises(MatchPatternError): validated_ruler.add_patterns([invalid_pattern]) + + +def test_entity_ruler_properties(nlp, patterns): + ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True) + assert sorted(ruler.labels) == sorted([ + "HELLO", + "BYE", + "COMPLEX", + "TECH_ORG" + ]) + assert sorted(ruler.ent_ids) == ["a1", "a2"]