mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-28 21:03:41 +03:00
Fix ent_ids and labels properties when id attribute used in patterns (#4900)
* Fix ent_ids and labels properties when id attribute used in patterns
* use set for labels
* sort ent_ids for comparison in entity_ruler tests
* fixing entity_ruler ent_ids test
* add to set
This commit is contained in:
parent
fbfc418745
commit
b9afcd56e3
|
@ -129,20 +129,31 @@ class EntityRuler(object):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entityruler#labels
|
DOCS: https://spacy.io/api/entityruler#labels
|
||||||
"""
|
"""
|
||||||
all_labels = set(self.token_patterns.keys())
|
keys = set(self.token_patterns.keys())
|
||||||
all_labels.update(self.phrase_patterns.keys())
|
keys.update(self.phrase_patterns.keys())
|
||||||
|
all_labels = set()
|
||||||
|
|
||||||
|
for l in keys:
|
||||||
|
if self.ent_id_sep in l:
|
||||||
|
label, _ = self._split_label(l)
|
||||||
|
all_labels.add(label)
|
||||||
|
else:
|
||||||
|
all_labels.add(l)
|
||||||
return tuple(all_labels)
|
return tuple(all_labels)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ent_ids(self):
|
def ent_ids(self):
|
||||||
"""All entity ids present in the match patterns `id` properties.
|
"""All entity ids present in the match patterns `id` properties
|
||||||
|
|
||||||
RETURNS (set): The string entity ids.
|
RETURNS (set): The string entity ids.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entityruler#ent_ids
|
DOCS: https://spacy.io/api/entityruler#ent_ids
|
||||||
"""
|
"""
|
||||||
|
keys = set(self.token_patterns.keys())
|
||||||
|
keys.update(self.phrase_patterns.keys())
|
||||||
all_ent_ids = set()
|
all_ent_ids = set()
|
||||||
for l in self.labels:
|
|
||||||
|
for l in keys:
|
||||||
if self.ent_id_sep in l:
|
if self.ent_id_sep in l:
|
||||||
_, ent_id = self._split_label(l)
|
_, ent_id = self._split_label(l)
|
||||||
all_ent_ids.add(ent_id)
|
all_ent_ids.add(ent_id)
|
||||||
|
|
|
@ -21,6 +21,7 @@ def patterns():
|
||||||
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
{"label": "HELLO", "pattern": [{"ORTH": "HELLO"}]},
|
||||||
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
|
{"label": "COMPLEX", "pattern": [{"ORTH": "foo", "OP": "*"}]},
|
||||||
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
|
{"label": "TECH_ORG", "pattern": "Apple", "id": "a1"},
|
||||||
|
{"label": "TECH_ORG", "pattern": "Microsoft", "id": "a2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,3 +148,14 @@ def test_entity_ruler_validate(nlp):
|
||||||
# invalid pattern raises error with validate
|
# invalid pattern raises error with validate
|
||||||
with pytest.raises(MatchPatternError):
|
with pytest.raises(MatchPatternError):
|
||||||
validated_ruler.add_patterns([invalid_pattern])
|
validated_ruler.add_patterns([invalid_pattern])
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_ruler_properties(nlp, patterns):
|
||||||
|
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||||
|
assert sorted(ruler.labels) == sorted([
|
||||||
|
"HELLO",
|
||||||
|
"BYE",
|
||||||
|
"COMPLEX",
|
||||||
|
"TECH_ORG"
|
||||||
|
])
|
||||||
|
assert sorted(ruler.ent_ids) == ["a1", "a2"]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user