Fix ent_ids and labels properties when id attribute used in patterns (#4900)

* Fix ent_ids and labels properties when id attribute used in patterns

* use set for labels

* sort ent_ids for comparison in entity_ruler tests

* fixing entity_ruler ent_ids test

* add to set
This commit is contained in:
Kabir Khan 2020-01-15 17:01:31 -08:00 committed by Ines Montani
parent fbfc418745
commit b9afcd56e3
2 changed files with 27 additions and 4 deletions

View File

@ -129,20 +129,31 @@ class EntityRuler(object):
DOCS: https://spacy.io/api/entityruler#labels
"""
all_labels = set(self.token_patterns.keys())
all_labels.update(self.phrase_patterns.keys())
keys = set(self.token_patterns.keys())
keys.update(self.phrase_patterns.keys())
all_labels = set()
for l in keys:
if self.ent_id_sep in l:
label, _ = self._split_label(l)
all_labels.add(label)
else:
all_labels.add(l)
return tuple(all_labels)
@property
def ent_ids(self):
"""All entity ids present in the match patterns `id` properties.
"""All entity ids present in the match patterns `id` properties
RETURNS (set): The string entity ids.
DOCS: https://spacy.io/api/entityruler#ent_ids
"""
keys = set(self.token_patterns.keys())
keys.update(self.phrase_patterns.keys())
all_ent_ids = set()
for l in self.labels:
for l in keys:
if self.ent_id_sep in l:
_, ent_id = self._split_label(l)
all_ent_ids.add(ent_id)

View File

@ -21,6 +21,7 @@ def patterns():
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
{"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
]
@ -147,3 +148,14 @@ def test_entity_ruler_validate(nlp):
# invalid pattern raises error with validate
with pytest.raises(MatchPatternError):
validated_ruler.add_patterns([invalid_pattern])
def test_entity_ruler_properties(nlp, patterns):
    """Check the `labels` and `ent_ids` properties of an EntityRuler.

    The ruler is built from the shared `patterns` fixture, some of whose
    entries carry an `id` attribute. Both properties are sorted before
    comparison so the assertions do not depend on iteration order.
    """
    expected_labels = ["BYE", "COMPLEX", "HELLO", "TECH_ORG"]
    expected_ent_ids = ["a1", "a2"]

    ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)

    # Labels must be reported without any entity-id suffix attached.
    assert sorted(ruler.labels) == expected_labels
    # Every distinct `id` used in the patterns must be reported exactly once.
    assert sorted(ruler.ent_ids) == expected_ent_ids