diff --git a/examples/information_extraction/entity_relations.py b/examples/information_extraction/entity_relations.py index 1b3ba1d27..c40a3c10d 100644 --- a/examples/information_extraction/entity_relations.py +++ b/examples/information_extraction/entity_relations.py @@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to – for example: $9.4 million --> Net income. Compatible with: spaCy v2.0.0+ -Last tested with: v2.1.0 +Last tested with: v2.2.1 """ from __future__ import unicode_literals, print_function @@ -38,14 +38,17 @@ def main(model="en_core_web_sm"): def filter_spans(spans): # Filter a sequence of spans so they don't contain overlaps - get_sort_key = lambda span: (span.end - span.start, span.start) + # For spaCy 2.1.4+: this function is available as spacy.util.filter_spans() + get_sort_key = lambda span: (span.end - span.start, -span.start) sorted_spans = sorted(spans, key=get_sort_key, reverse=True) result = [] seen_tokens = set() for span in sorted_spans: + # Check for end - 1 here because boundaries are inclusive if span.start not in seen_tokens and span.end - 1 not in seen_tokens: result.append(span) - seen_tokens.update(range(span.start, span.end)) + seen_tokens.update(range(span.start, span.end)) + result = sorted(result, key=lambda span: span.start) return result diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index c8c809d24..f813a9743 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -253,3 +253,11 @@ def test_filter_spans(doc): assert len(filtered[1]) == 5 assert filtered[0].start == 1 and filtered[0].end == 4 assert filtered[1].start == 5 and filtered[1].end == 10 + # Test filtering overlaps with earlier preference for identical length + spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]] + filtered = filter_spans(spans) + assert len(filtered) == 2 + assert len(filtered[0]) == 3 + assert len(filtered[1]) == 5 + assert filtered[0].start == 1 and filtered[0].end == 4 + assert filtered[1].start == 5 and filtered[1].end == 10 diff --git a/spacy/util.py b/spacy/util.py index 9798ff11b..fa8111d67 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -666,7 +666,7 @@ def filter_spans(spans): spans (iterable): The spans to filter. RETURNS (list): The filtered spans. """ - get_sort_key = lambda span: (span.end - span.start, span.start) + get_sort_key = lambda span: (span.end - span.start, -span.start) sorted_spans = sorted(spans, key=get_sort_key, reverse=True) result = [] seen_tokens = set()