Fix util.filter_spans() to prefer first span in overlapping sam… (#4414)

* Update util.filter_spans() to prefer earlier spans

* Add filter_spans test for first same-length span

* Update entity relation example to refer to util.filter_spans()
This commit is contained in:
adrianeboyd 2019-10-10 17:00:03 +02:00 committed by Ines Montani
parent da6e0de34f
commit 6f54e59fe7
3 changed files with 15 additions and 4 deletions

View File

@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to for example:
$9.4 million --> Net income.
Compatible with: spaCy v2.0.0+
Last tested with: v2.1.0
Last tested with: v2.2.1
"""
from __future__ import unicode_literals, print_function
@ -38,14 +38,17 @@ def main(model="en_core_web_sm"):
def filter_spans(spans):
# Filter a sequence of spans so they don't contain overlaps
get_sort_key = lambda span: (span.end - span.start, span.start)
# For spaCy 2.1.4+: this function is available as spacy.util.filter_spans()
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
seen_tokens = set()
for span in sorted_spans:
# Check for end - 1 here because boundaries are inclusive
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
result.append(span)
seen_tokens.update(range(span.start, span.end))
seen_tokens.update(range(span.start, span.end))
result = sorted(result, key=lambda span: span.start)
return result

View File

@ -253,3 +253,11 @@ def test_filter_spans(doc):
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10
# Test filtering overlaps with earlier preference for identical length
spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]]
filtered = filter_spans(spans)
assert len(filtered) == 2
assert len(filtered[0]) == 3
assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10

View File

@ -666,7 +666,7 @@ def filter_spans(spans):
spans (iterable): The spans to filter.
RETURNS (list): The filtered spans.
"""
get_sort_key = lambda span: (span.end - span.start, span.start)
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
seen_tokens = set()