mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix util.filter_spans() to prefer first span in overlapping sam… (#4414)
* Update util.filter_spans() to prefer earlier spans * Add filter_spans test for first same-length span * Update entity relation example to refer to util.filter_spans()
This commit is contained in:
parent
da6e0de34f
commit
6f54e59fe7
|
@ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to – for example:
|
|||
$9.4 million --> Net income.
|
||||
|
||||
Compatible with: spaCy v2.0.0+
|
||||
Last tested with: v2.1.0
|
||||
Last tested with: v2.2.1
|
||||
"""
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
|
@ -38,14 +38,17 @@ def main(model="en_core_web_sm"):
|
|||
|
||||
def filter_spans(spans):
|
||||
# Filter a sequence of spans so they don't contain overlaps
|
||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
||||
# For spaCy 2.1.4+: this function is available as spacy.util.filter_spans()
|
||||
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||
result = []
|
||||
seen_tokens = set()
|
||||
for span in sorted_spans:
|
||||
# Check for end - 1 here because boundaries are inclusive
|
||||
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
|
||||
result.append(span)
|
||||
seen_tokens.update(range(span.start, span.end))
|
||||
result = sorted(result, key=lambda span: span.start)
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -253,3 +253,11 @@ def test_filter_spans(doc):
|
|||
assert len(filtered[1]) == 5
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||
# Test filtering overlaps with earlier preference for identical length
|
||||
spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]]
|
||||
filtered = filter_spans(spans)
|
||||
assert len(filtered) == 2
|
||||
assert len(filtered[0]) == 3
|
||||
assert len(filtered[1]) == 5
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||
|
|
|
@ -666,7 +666,7 @@ def filter_spans(spans):
|
|||
spans (iterable): The spans to filter.
|
||||
RETURNS (list): The filtered spans.
|
||||
"""
|
||||
get_sort_key = lambda span: (span.end - span.start, span.start)
|
||||
get_sort_key = lambda span: (span.end - span.start, -span.start)
|
||||
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
|
||||
result = []
|
||||
seen_tokens = set()
|
||||
|
|
Loading…
Reference in New Issue
Block a user