mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix util.filter_spans() to prefer first span in overlapping sam… (#4414)
* Update util.filter_spans() to prefer earlier spans * Add filter_spans test for first same-length span * Update entity relation example to refer to util.filter_spans()
This commit is contained in:
parent
da6e0de34f
commit
6f54e59fe7
|
@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to – for example:
|
||||||
$9.4 million --> Net income.
|
$9.4 million --> Net income.
|
||||||
|
|
||||||
Compatible with: spaCy v2.0.0+
|
Compatible with: spaCy v2.0.0+
|
||||||
Last tested with: v2.1.0
|
Last tested with: v2.2.1
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals, print_function
|
from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
|
@ -38,14 +38,17 @@ def main(model="en_core_web_sm"):
|
||||||
|
|
||||||
def filter_spans(spans):
|
def filter_spans(spans):
|
||||||
# Filter a sequence of spans so they don't contain overlaps
|
# Filter a sequence of spans so they don't contain overlaps
|
||||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
# For spaCy 2.1.4+: this function is available as spacy.util.filter_spans()
|
||||||
|
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||||
result = []
|
result = []
|
||||||
seen_tokens = set()
|
seen_tokens = set()
|
||||||
for span in sorted_spans:
|
for span in sorted_spans:
|
||||||
|
# Check for end - 1 here because boundaries are inclusive
|
||||||
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
||||||
result.append(span)
|
result.append(span)
|
||||||
seen_tokens.update(range(span.start, span.end))
|
seen_tokens.update(range(span.start, span.end))
|
||||||
|
result = sorted(result, key=lambda span: span.start)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -253,3 +253,11 @@ def test_filter_spans(doc):
|
||||||
assert len(filtered[1]) == 5
|
assert len(filtered[1]) == 5
|
||||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||||
|
# Test filtering overlaps with earlier preference for identical length
|
||||||
|
spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]]
|
||||||
|
filtered = filter_spans(spans)
|
||||||
|
assert len(filtered) == 2
|
||||||
|
assert len(filtered[0]) == 3
|
||||||
|
assert len(filtered[1]) == 5
|
||||||
|
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||||
|
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||||
|
|
|
@ -666,7 +666,7 @@ def filter_spans(spans):
|
||||||
spans (iterable): The spans to filter.
|
spans (iterable): The spans to filter.
|
||||||
RETURNS (list): The filtered spans.
|
RETURNS (list): The filtered spans.
|
||||||
"""
|
"""
|
||||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||||
result = []
|
result = []
|
||||||
seen_tokens = set()
|
seen_tokens = set()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user