Prefer earlier spans in EntityRuler (#5843)

Similar to #4414, update the sorting in EntityRuler to prefer the first
span in overlapping spans.
This commit is contained in:
Adriane Boyd 2020-07-31 16:09:32 +02:00 committed by GitHub
parent d16c0f2c3a
commit ac14ce7c30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 1 deletions

View File

@ -95,7 +95,7 @@ class EntityRuler(object):
matches = set( matches = set(
[(m_id, start, end) for m_id, start, end in matches if start != end] [(m_id, start, end) for m_id, start, end in matches if start != end]
) )
get_sort_key = lambda m: (m[2] - m[1], m[1]) get_sort_key = lambda m: (m[2] - m[1], -m[1])
matches = sorted(matches, key=get_sort_key, reverse=True) matches = sorted(matches, key=get_sort_key, reverse=True)
entities = list(doc.ents) entities = list(doc.ents)
new_entities = [] new_entities = []

View File

@ -154,3 +154,15 @@ def test_entity_ruler_properties(nlp, patterns):
ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True) ruler = EntityRuler(nlp, patterns=patterns, overwrite_ents=True)
assert sorted(ruler.labels) == sorted(["HELLO", "BYE", "COMPLEX", "TECH_ORG"]) assert sorted(ruler.labels) == sorted(["HELLO", "BYE", "COMPLEX", "TECH_ORG"])
assert sorted(ruler.ent_ids) == ["a1", "a2"] assert sorted(ruler.ent_ids) == ["a1", "a2"]
def test_entity_ruler_overlapping_spans(nlp):
ruler = EntityRuler(nlp)
patterns = [
{"label": "FOOBAR", "pattern": "foo bar"},
{"label": "BARBAZ", "pattern": "bar baz"},
]
ruler.add_patterns(patterns)
doc = ruler(nlp.make_doc("foo bar baz"))
assert len(doc.ents) == 1
assert doc.ents[0].label_ == "FOOBAR"