mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-29 11:26:28 +03:00
162 lines
8.6 KiB
Python
162 lines
8.6 KiB
Python
|
import pytest
|
||
|
|
||
|
from spacy.tokens import Span, SpanGroup
|
||
|
from spacy.tokens._dict_proxies import SpanGroups
|
||
|
|
||
|
|
||
|
@pytest.mark.issue(10685)
def test_issue10685(en_tokenizer):
    """Check that `SpanGroups` survives a to_bytes/from_bytes round trip.

    Covers both an empty `SpanGroups` and one holding two groups that are
    stored under different keys but share the same `SpanGroup.name`.
    """
    doc = en_tokenizer("Will it blend?")

    # A freshly tokenized Doc has no span groups; round-tripping the empty
    # container must leave it empty.
    assert len(doc.spans) == 0
    doc.spans.from_bytes(doc.spans.to_bytes())
    assert len(doc.spans) == 0

    # Two groups under distinct keys, both named "test".
    doc.spans["test"] = SpanGroup(doc, name="test", spans=[doc[0:1]])
    doc.spans["test2"] = SpanGroup(doc, name="test", spans=[doc[1:2]])

    def check_groups():
        # key -> (expected group name, (start, end) token bounds of its span)
        expected = {"test": ("test", (0, 1)), "test2": ("test", (1, 2))}
        assert len(doc.spans) == 2
        for key, (name, _) in expected.items():
            assert doc.spans[key].name == name
        for key, (_, (start, end)) in expected.items():
            assert list(doc.spans[key]) == [doc[start:end]]

    # Sanity check the state before serialization ...
    check_groups()

    # ... and verify it is unchanged after the round trip.
    doc.spans.from_bytes(doc.spans.to_bytes())
    check_groups()
|
||
|
|
||
|
|
||
|
def test_span_groups_serialization_mismatches(en_tokenizer):
    """Round-trip `SpanGroups` whose keys and `SpanGroup.name`s disagree.

    The fixture mixes the following cases:
      * key1:        a group whose name equals its key
      * key2, key3:  two groups sharing a name that is not a key
      * key4, key5:  two groups sharing a name that is a key
      * key6, key7:  one group object under two keys (its name is a key)
      * key8, key9:  one group object under two keys (its name is not a key)
    """
    doc = en_tokenizer("How now, brown cow?")
    groups = doc.spans

    groups["key1"] = SpanGroup(doc, name="key1", spans=[doc[0:1], doc[1:2]])
    groups["key2"] = SpanGroup(doc, name="too", spans=[doc[3:4], doc[4:5]])
    groups["key3"] = SpanGroup(doc, name="too", spans=[doc[1:2], doc[0:1]])
    groups["key4"] = SpanGroup(doc, name="key4", spans=[doc[0:1]])
    groups["key5"] = SpanGroup(doc, name="key4", spans=[doc[0:1]])

    shared_keyed = SpanGroup(doc, name="key6", spans=[doc[0:1]])
    groups["key6"] = shared_keyed
    groups["key7"] = shared_keyed

    shared_unkeyed = SpanGroup(doc, name="also", spans=[doc[1:2]])
    groups["key8"] = shared_unkeyed
    groups["key9"] = shared_unkeyed

    serialized = groups.to_bytes()
    regroups = SpanGroups(doc).from_bytes(serialized)

    # The reloaded container must expose the same keys, and every reloaded
    # group must match the original in both name and span content.
    assert regroups.keys() == groups.keys()
    for key in regroups:
        restored = regroups[key]
        original = groups[key]
        assert restored.name == original.name
        assert list(restored) == list(original)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "spans_bytes,doc_text,expected_spangroups,expected_warning",
    # The bytestrings below were generated from an earlier version of spaCy
    # that serialized `SpanGroups` as a list of SpanGroup bytes (via SpanGroups.to_bytes).
    # Comments preceding the bytestrings indicate from what Doc they were created.
    [
        # Empty SpanGroups:
        (b"\x90", "", {}, False),
        # doc = nlp("Will it blend?")
        # doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]])
        (
            b"\x91\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04",
            "Will it blend?",
            {"test": {"name": "test", "spans": [(0, 1)]}},
            False,
        ),
        # doc = nlp("Will it blend?")
        # doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]])
        # doc.spans['test2'] = SpanGroup(doc, name='test', spans=[doc[1:2]])
        (
            b"\x92\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x07",
            "Will it blend?",
            # We expect only 1 SpanGroup to be in doc.spans in this example
            # because there are 2 `SpanGroup`s that have the same .name. See #10685.
            {"test": {"name": "test", "spans": [(1, 2)]}},
            True,
        ),
        # doc = nlp('How now, brown cow?')
        # doc.spans['key1'] = SpanGroup(doc, name='key1', spans=[doc[0:1], doc[1:2]])
        # doc.spans['key2'] = SpanGroup(doc, name='too', spans=[doc[3:4], doc[4:5]])
        # doc.spans['key3'] = SpanGroup(doc, name='too', spans=[doc[1:2], doc[0:1]])
        # doc.spans['key4'] = SpanGroup(doc, name='key4', spans=[doc[0:1]])
        # doc.spans['key5'] = SpanGroup(doc, name='key4', spans=[doc[0:1]])
        (
            b"\x95\xc4m\x83\xa4name\xa4key1\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\t\x00\x00\x00\x0e\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x05\x00\x00\x00\x0f\x00\x00\x00\x12\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03",
            "How now, brown cow?",
            {
                "key1": {"name": "key1", "spans": [(0, 1), (1, 2)]},
                "too": {"name": "too", "spans": [(1, 2), (0, 1)]},
                "key4": {"name": "key4", "spans": [(0, 1)]},
            },
            True,
        ),
    ],
)
def test_deserialize_span_groups_compat(
    en_tokenizer, spans_bytes, doc_text, expected_spangroups, expected_warning
):
    """Test backwards-compatibility of `SpanGroups` deserialization.
    This uses serializations (bytes) from a prior version of spaCy (before 3.3.1).

    spans_bytes (bytes): Serialized `SpanGroups` object.
    doc_text (str): Doc text.
    expected_spangroups (dict):
        Dict mapping every expected (after deserialization) `SpanGroups` key
        to a SpanGroup's "args", where a SpanGroup's args are given as a dict:
          {"name": span_group.name,
           "spans": [(span0.start, span0.end), ...]}
    expected_warning (bool): Whether a warning is to be expected from .from_bytes()
        --i.e. if more than 1 SpanGroup has the same .name within the `SpanGroups`.
        When False, any UserWarning raised during deserialization fails the test.
    """
    # Imported locally so this test does not depend on the module's import block.
    import warnings

    doc = en_tokenizer(doc_text)

    if expected_warning:
        with pytest.warns(UserWarning):
            doc.spans.from_bytes(spans_bytes)
    else:
        # Escalate UserWarnings to errors so that an unexpected warning from
        # .from_bytes() fails the test instead of passing silently.
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            doc.spans.from_bytes(spans_bytes)

    assert doc.spans.keys() == expected_spangroups.keys()
    for name, spangroup_args in expected_spangroups.items():
        assert doc.spans[name].name == spangroup_args["name"]
        # Rebuild the expected spans from their (start, end) token offsets.
        spans = [Span(doc, start, end) for start, end in spangroup_args["spans"]]
        assert list(doc.spans[name]) == spans
|
||
|
|
||
|
|
||
|
def test_span_groups_serialization(en_tokenizer):
    """Round-trip a `SpanGroups` built from a shared group and an empty list.

    The same `SpanGroup` object is stored under two keys, and a plain empty
    list is stored under a third. After to_bytes/from_bytes the reloaded
    container must have the same keys and, for every key, exactly the same
    sequence of spans.
    """
    doc = en_tokenizer("0 1 2 3 4 5 6")
    span_groups = SpanGroups(doc)
    spans = [doc[0:2], doc[1:3]]
    sg1 = SpanGroup(doc, spans=spans)
    span_groups["key1"] = sg1
    span_groups["key2"] = sg1
    span_groups["key3"] = []
    reloaded_span_groups = SpanGroups(doc).from_bytes(span_groups.to_bytes())
    assert span_groups.keys() == reloaded_span_groups.keys()
    for key in span_groups:
        # Compare whole lists rather than zipping: zip() stops at the shorter
        # input, which would hide spans that were dropped on reload.
        assert list(span_groups[key]) == list(reloaded_span_groups[key])
|