mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Fix: De/Serialize SpanGroups
including the SpanGroup keys (#10707)
* fix: De/Serialize `SpanGroups` including the SpanGroup keys This prevents the loss of `SpanGroup`s that have the same .name as other `SpanGroup`s within the same `SpanGroups` object (upon de/serialization of the `SpanGroups`). Fixes #10685 * Maintain backwards compatibility for serialized `SpanGroups` (serialized as: a list of `SpanGroup`s, or b'') * Add tests for `SpanGroups` deserialization backwards-compatibility * Move a `SpanGroups` de/serialization test (test_issue10685) to tests/serialize/test_serialize_spangroups.py * Output a warning if deserializing a `SpanGroups` with duplicate .name-d `SpanGroup`s * Minor refactor * `SpanGroups.from_bytes` handles only `list` and `dict` types with `dict` as the expected default * For lists, keep first rather than last value encountered * Update error message * Rename and update tests * Update to preserve list serialization of SpanGroups To avoid breaking compatibility of serialized `Doc` and `DocBin` with earlier versions of spacy v3, revert back to a list-only serialization, but update the names just for serialization so that the SpanGroups keys override the SpanGroup names. * Preserve object identity and current key overwrite * Preserve SpanGroup object identity * Preserve last rather than first span group from SpanGroup list format without SpanGroups keys * Update inline comments * Fix types * Add type info for SpanGroup.copy * Deserialize `SpanGroup`s as copies when a single SpanGroup is the value for more than 1 `SpanGroups` key. This is because we serialize `SpanGroups` as dicts (to maintain backward- and forward-compatibility) and we can't assume `SpanGroup`s with the same bytes/serialization were the same (identical) object, pre-serialization. * Update spacy/tokens/_dict_proxies.py * Add more SpanGroups serialization tests Test that serialized SpanGroups maintain their Span order * small clarification on older spaCy version * Update spacy/tests/serialize/test_serialize_span_groups.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
7e13652d36
commit
6c6b8da7cc
|
@ -204,6 +204,11 @@ class Warnings(metaclass=ErrorsWithCodes):
|
|||
"for the corpora used to train the language. Please check "
|
||||
"`nlp.meta[\"sources\"]` for any relevant links.")
|
||||
W119 = ("Overriding pipe name in `config` is not supported. Ignoring override '{name_in_config}'.")
|
||||
W120 = ("Unable to load all spans in Doc.spans: more than one span group "
|
||||
"with the name '{group_name}' was found in the saved spans data. "
|
||||
"Only the last span group will be loaded under "
|
||||
"Doc.spans['{group_name}']. Skipping span group with values: "
|
||||
"{group_values}")
|
||||
|
||||
|
||||
class Errors(metaclass=ErrorsWithCodes):
|
||||
|
|
161
spacy/tests/serialize/test_serialize_span_groups.py
Normal file
161
spacy/tests/serialize/test_serialize_span_groups.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
import pytest
|
||||
|
||||
from spacy.tokens import Span, SpanGroup
|
||||
from spacy.tokens._dict_proxies import SpanGroups
|
||||
|
||||
|
||||
@pytest.mark.issue(10685)
|
||||
def test_issue10685(en_tokenizer):
|
||||
"""Test `SpanGroups` de/serialization"""
|
||||
# Start with a Doc with no SpanGroups
|
||||
doc = en_tokenizer("Will it blend?")
|
||||
|
||||
# Test empty `SpanGroups` de/serialization:
|
||||
assert len(doc.spans) == 0
|
||||
doc.spans.from_bytes(doc.spans.to_bytes())
|
||||
assert len(doc.spans) == 0
|
||||
|
||||
# Test non-empty `SpanGroups` de/serialization:
|
||||
doc.spans["test"] = SpanGroup(doc, name="test", spans=[doc[0:1]])
|
||||
doc.spans["test2"] = SpanGroup(doc, name="test", spans=[doc[1:2]])
|
||||
|
||||
def assert_spangroups():
|
||||
assert len(doc.spans) == 2
|
||||
assert doc.spans["test"].name == "test"
|
||||
assert doc.spans["test2"].name == "test"
|
||||
assert list(doc.spans["test"]) == [doc[0:1]]
|
||||
assert list(doc.spans["test2"]) == [doc[1:2]]
|
||||
|
||||
# Sanity check the currently-expected behavior
|
||||
assert_spangroups()
|
||||
|
||||
# Now test serialization/deserialization:
|
||||
doc.spans.from_bytes(doc.spans.to_bytes())
|
||||
|
||||
assert_spangroups()
|
||||
|
||||
|
||||
def test_span_groups_serialization_mismatches(en_tokenizer):
|
||||
"""Test the serialization of multiple mismatching `SpanGroups` keys and `SpanGroup.name`s"""
|
||||
doc = en_tokenizer("How now, brown cow?")
|
||||
# Some variety:
|
||||
# 1 SpanGroup where its name matches its key
|
||||
# 2 SpanGroups that have the same name--which is not a key
|
||||
# 2 SpanGroups that have the same name--which is a key
|
||||
# 1 SpanGroup that is a value for 2 different keys (where its name is a key)
|
||||
# 1 SpanGroup that is a value for 2 different keys (where its name is not a key)
|
||||
groups = doc.spans
|
||||
groups["key1"] = SpanGroup(doc, name="key1", spans=[doc[0:1], doc[1:2]])
|
||||
groups["key2"] = SpanGroup(doc, name="too", spans=[doc[3:4], doc[4:5]])
|
||||
groups["key3"] = SpanGroup(doc, name="too", spans=[doc[1:2], doc[0:1]])
|
||||
groups["key4"] = SpanGroup(doc, name="key4", spans=[doc[0:1]])
|
||||
groups["key5"] = SpanGroup(doc, name="key4", spans=[doc[0:1]])
|
||||
sg6 = SpanGroup(doc, name="key6", spans=[doc[0:1]])
|
||||
groups["key6"] = sg6
|
||||
groups["key7"] = sg6
|
||||
sg8 = SpanGroup(doc, name="also", spans=[doc[1:2]])
|
||||
groups["key8"] = sg8
|
||||
groups["key9"] = sg8
|
||||
|
||||
regroups = SpanGroups(doc).from_bytes(groups.to_bytes())
|
||||
|
||||
# Assert regroups == groups
|
||||
assert regroups.keys() == groups.keys()
|
||||
for key, regroup in regroups.items():
|
||||
# Assert regroup == groups[key]
|
||||
assert regroup.name == groups[key].name
|
||||
assert list(regroup) == list(groups[key])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spans_bytes,doc_text,expected_spangroups,expected_warning",
|
||||
# The bytestrings below were generated from an earlier version of spaCy
|
||||
# that serialized `SpanGroups` as a list of SpanGroup bytes (via SpanGroups.to_bytes).
|
||||
# Comments preceding the bytestrings indicate from what Doc they were created.
|
||||
[
|
||||
# Empty SpanGroups:
|
||||
(b"\x90", "", {}, False),
|
||||
# doc = nlp("Will it blend?")
|
||||
# doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]])
|
||||
(
|
||||
b"\x91\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04",
|
||||
"Will it blend?",
|
||||
{"test": {"name": "test", "spans": [(0, 1)]}},
|
||||
False,
|
||||
),
|
||||
# doc = nlp("Will it blend?")
|
||||
# doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]])
|
||||
# doc.spans['test2'] = SpanGroup(doc, name='test', spans=[doc[1:2]])
|
||||
(
|
||||
b"\x92\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x07",
|
||||
"Will it blend?",
|
||||
# We expect only 1 SpanGroup to be in doc.spans in this example
|
||||
# because there are 2 `SpanGroup`s that have the same .name. See #10685.
|
||||
{"test": {"name": "test", "spans": [(1, 2)]}},
|
||||
True,
|
||||
),
|
||||
# doc = nlp('How now, brown cow?')
|
||||
# doc.spans['key1'] = SpanGroup(doc, name='key1', spans=[doc[0:1], doc[1:2]])
|
||||
# doc.spans['key2'] = SpanGroup(doc, name='too', spans=[doc[3:4], doc[4:5]])
|
||||
# doc.spans['key3'] = SpanGroup(doc, name='too', spans=[doc[1:2], doc[0:1]])
|
||||
# doc.spans['key4'] = SpanGroup(doc, name='key4', spans=[doc[0:1]])
|
||||
# doc.spans['key5'] = SpanGroup(doc, name='key4', spans=[doc[0:1]])
|
||||
(
|
||||
b"\x95\xc4m\x83\xa4name\xa4key1\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\t\x00\x00\x00\x0e\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x05\x00\x00\x00\x0f\x00\x00\x00\x12\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03",
|
||||
"How now, brown cow?",
|
||||
{
|
||||
"key1": {"name": "key1", "spans": [(0, 1), (1, 2)]},
|
||||
"too": {"name": "too", "spans": [(1, 2), (0, 1)]},
|
||||
"key4": {"name": "key4", "spans": [(0, 1)]},
|
||||
},
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_deserialize_span_groups_compat(
|
||||
en_tokenizer, spans_bytes, doc_text, expected_spangroups, expected_warning
|
||||
):
|
||||
"""Test backwards-compatibility of `SpanGroups` deserialization.
|
||||
This uses serializations (bytes) from a prior version of spaCy (before 3.3.1).
|
||||
|
||||
spans_bytes (bytes): Serialized `SpanGroups` object.
|
||||
doc_text (str): Doc text.
|
||||
expected_spangroups (dict):
|
||||
Dict mapping every expected (after deserialization) `SpanGroups` key
|
||||
to a SpanGroup's "args", where a SpanGroup's args are given as a dict:
|
||||
{"name": span_group.name,
|
||||
"spans": [(span0.start, span0.end), ...]}
|
||||
expected_warning (bool): Whether a warning is to be expected from .from_bytes()
|
||||
--i.e. if more than 1 SpanGroup has the same .name within the `SpanGroups`.
|
||||
"""
|
||||
doc = en_tokenizer(doc_text)
|
||||
|
||||
if expected_warning:
|
||||
with pytest.warns(UserWarning):
|
||||
doc.spans.from_bytes(spans_bytes)
|
||||
else:
|
||||
# TODO: explicitly check for lack of a warning
|
||||
doc.spans.from_bytes(spans_bytes)
|
||||
|
||||
assert doc.spans.keys() == expected_spangroups.keys()
|
||||
for name, spangroup_args in expected_spangroups.items():
|
||||
assert doc.spans[name].name == spangroup_args["name"]
|
||||
spans = [Span(doc, start, end) for start, end in spangroup_args["spans"]]
|
||||
assert list(doc.spans[name]) == spans
|
||||
|
||||
|
||||
def test_span_groups_serialization(en_tokenizer):
|
||||
doc = en_tokenizer("0 1 2 3 4 5 6")
|
||||
span_groups = SpanGroups(doc)
|
||||
spans = [doc[0:2], doc[1:3]]
|
||||
sg1 = SpanGroup(doc, spans=spans)
|
||||
span_groups["key1"] = sg1
|
||||
span_groups["key2"] = sg1
|
||||
span_groups["key3"] = []
|
||||
reloaded_span_groups = SpanGroups(doc).from_bytes(span_groups.to_bytes())
|
||||
assert span_groups.keys() == reloaded_span_groups.keys()
|
||||
for key, value in span_groups.items():
|
||||
assert all(
|
||||
span == reloaded_span
|
||||
for span, reloaded_span in zip(span_groups[key], reloaded_span_groups[key])
|
||||
)
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING
|
||||
from typing import Dict, Iterable, List, Tuple, Union, Optional, TYPE_CHECKING
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import UserDict
|
||||
import srsly
|
||||
|
||||
from .span_group import SpanGroup
|
||||
from ..errors import Errors
|
||||
from ..errors import Errors, Warnings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -16,7 +17,7 @@ if TYPE_CHECKING:
|
|||
# Why inherit from UserDict instead of dict here?
|
||||
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
|
||||
# for performance reasons. The UserDict is slower but better behaved.
|
||||
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww
|
||||
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/
|
||||
class SpanGroups(UserDict):
|
||||
"""A dict-like proxy held by the Doc, to control access to span groups."""
|
||||
|
||||
|
@ -53,20 +54,52 @@ class SpanGroups(UserDict):
|
|||
return super().setdefault(key, default=default)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
# We don't need to serialize this as a dict, because the groups
|
||||
# know their names.
|
||||
# We serialize this as a dict in order to track the key(s) a SpanGroup
|
||||
# is a value of (in a backward- and forward-compatible way), since
|
||||
# a SpanGroup can have a key that doesn't match its `.name` (See #10685)
|
||||
if len(self) == 0:
|
||||
return self._EMPTY_BYTES
|
||||
msg = [value.to_bytes() for value in self.values()]
|
||||
msg: Dict[bytes, List[str]] = {}
|
||||
for key, value in self.items():
|
||||
msg.setdefault(value.to_bytes(), []).append(key)
|
||||
return srsly.msgpack_dumps(msg)
|
||||
|
||||
def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
|
||||
msg = [] if bytes_data == self._EMPTY_BYTES else srsly.msgpack_loads(bytes_data)
|
||||
# backwards-compatibility: bytes_data may be one of:
|
||||
# b'', a serialized empty list, a serialized list of SpanGroup bytes
|
||||
# or a serialized dict of SpanGroup bytes -> keys
|
||||
msg = (
|
||||
[]
|
||||
if not bytes_data or bytes_data == self._EMPTY_BYTES
|
||||
else srsly.msgpack_loads(bytes_data)
|
||||
)
|
||||
self.clear()
|
||||
doc = self._ensure_doc()
|
||||
if isinstance(msg, list):
|
||||
# This is either the 1st version of `SpanGroups` serialization
|
||||
# or there were no SpanGroups serialized
|
||||
for value_bytes in msg:
|
||||
group = SpanGroup(doc).from_bytes(value_bytes)
|
||||
if group.name in self:
|
||||
# Display a warning if `msg` contains `SpanGroup`s
|
||||
# that have the same .name (attribute).
|
||||
# Because, for `SpanGroups` serialized as lists,
|
||||
# only 1 SpanGroup per .name is loaded. (See #10685)
|
||||
warnings.warn(
|
||||
Warnings.W120.format(
|
||||
group_name=group.name, group_values=self[group.name]
|
||||
)
|
||||
)
|
||||
self[group.name] = group
|
||||
else:
|
||||
for value_bytes, keys in msg.items():
|
||||
group = SpanGroup(doc).from_bytes(value_bytes)
|
||||
# Deserialize `SpanGroup`s as copies because it's possible for two
|
||||
# different `SpanGroup`s (pre-serialization) to have the same bytes
|
||||
# (since they can have the same `.name`).
|
||||
self[keys[0]] = group
|
||||
for key in keys[1:]:
|
||||
self[key] = group.copy()
|
||||
return self
|
||||
|
||||
def _ensure_doc(self) -> "Doc":
|
||||
|
|
|
@ -24,3 +24,4 @@ class SpanGroup:
|
|||
def __getitem__(self, i: int) -> Span: ...
|
||||
def to_bytes(self) -> bytes: ...
|
||||
def from_bytes(self, bytes_data: bytes) -> SpanGroup: ...
|
||||
def copy(self) -> SpanGroup: ...
|
||||
|
|
Loading…
Reference in New Issue
Block a user