mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix ent_ids and labels properties when id attribute used in patterns (#4900)
* Fix ent_ids and labels properties when id attribute used in patterns * use set for labels * sort end_ids for comparison in entity_ruler tests * fixing entity_ruler ent_ids test * add to set
This commit is contained in:
parent
fbfc418745
commit
b9afcd56e3
|
@ -129,20 +129,31 @@ class EntityRuler(object):
|
|||
|
||||
DOCS: https://spacy.io/api/entityruler#labels
|
||||
"""
|
||||
all_labels = set(self.token_patterns.keys())
|
||||
all_labels.update(self.phrase_patterns.keys())
|
||||
keys = set(self.token_patterns.keys())
|
||||
keys.update(self.phrase_patterns.keys())
|
||||
all_labels = set()
|
||||
|
||||
for l in keys:
|
||||
if self.ent_id_sep in l:
|
||||
label, _ = self._split_label(l)
|
||||
all_labels.add(label)
|
||||
else:
|
||||
all_labels.add(l)
|
||||
return tuple(all_labels)
|
||||
|
||||
@property
|
||||
def ent_ids(self):
|
||||
"""All entity ids present in the match patterns `id` properties.
|
||||
"""All entity ids present in the match patterns `id` properties
|
||||
|
||||
RETURNS (set): The string entity ids.
|
||||
|
||||
DOCS: https://spacy.io/api/entityruler#ent_ids
|
||||
"""
|
||||
keys = set(self.token_patterns.keys())
|
||||
keys.update(self.phrase_patterns.keys())
|
||||
all_ent_ids = set()
|
||||
for l in self.labels:
|
||||
|
||||
for l in keys:
|
||||
if self.ent_id_sep in l:
|
||||
_, ent_id = self._split_label(l)
|
||||
all_ent_ids.add(ent_id)
|
||||
|
|
|
@ -21,6 +21,7 @@ def patterns():
|
|||
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
||||
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
|
||||
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
|
||||
{"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
|
||||
]
|
||||
|
||||
|
||||
|
@ -147,3 +148,14 @@ def test_entity_ruler_validate(nlp):
|
|||
# invalid pattern raises error with validate
|
||||
with pytest.raises(MatchPatternError):
|
||||
validated_ruler.add_patterns([invalid_pattern])
|
||||
|
||||
|
||||
def test_entity_ruler_properties(nlp, patterns):
|
||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||
assert sorted(ruler.labels) == sorted([
|
||||
"HELLO",
|
||||
"BYE",
|
||||
"COMPLEX",
|
||||
"TECH_ORG"
|
||||
])
|
||||
assert sorted(ruler.ent_ids) == ["a1", "a2"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user