mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
Prefer earlier spans in EntityRuler (#5843)
Similar to #4414, update the sorting in EntityRuler so that, when overlapping spans have the same length, the span that starts earlier is preferred (longer spans still take priority).
This commit is contained in:
parent
d16c0f2c3a
commit
ac14ce7c30
|
@@ -95,7 +95,7 @@ class EntityRuler(object):
|
||||||
matches = set(
|
matches = set(
|
||||||
[(m_id, start, end) for m_id, start, end in matches if start != end]
|
[(m_id, start, end) for m_id, start, end in matches if start != end]
|
||||||
)
|
)
|
||||||
get_sort_key = lambda m: (m[2] - m[1], m[1])
|
get_sort_key = lambda m: (m[2] - m[1], -m[1])
|
||||||
matches = sorted(matches, key=get_sort_key, reverse=True)
|
matches = sorted(matches, key=get_sort_key, reverse=True)
|
||||||
entities = list(doc.ents)
|
entities = list(doc.ents)
|
||||||
new_entities = []
|
new_entities = []
|
||||||
|
|
|
@@ -154,3 +154,15 @@ def test_entity_ruler_properties(nlp, patterns):
|
||||||
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
|
||||||
assert sorted(ruler.labels) == sorted(["HELLO", "BYE", "COMPLEX", "TECH_ORG"])
|
assert sorted(ruler.labels) == sorted(["HELLO", "BYE", "COMPLEX", "TECH_ORG"])
|
||||||
assert sorted(ruler.ent_ids) == ["a1", "a2"]
|
assert sorted(ruler.ent_ids) == ["a1", "a2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_entity_ruler_overlapping_spans(nlp):
|
||||||
|
ruler = EntityRuler(nlp)
|
||||||
|
patterns = [
|
||||||
|
{"label": "FOOBAR", "pattern": "foo bar"},
|
||||||
|
{"label": "BARBAZ", "pattern": "bar baz"},
|
||||||
|
]
|
||||||
|
ruler.add_patterns(patterns)
|
||||||
|
doc = ruler(nlp.make_doc("foo bar baz"))
|
||||||
|
assert len(doc.ents) == 1
|
||||||
|
assert doc.ents[0].label_ == "FOOBAR"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user