mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Fix: De/Serialize SpanGroups including the SpanGroup keys (#10707)
				
					
				
			* fix: De/Serialize `SpanGroups` including the SpanGroup keys This prevents the loss of `SpanGroup`s that have the same .name as other `SpanGroup`s within the same `SpanGroups` object (upon de/serialization of the `SpanGroups`). Fixes #10685 * Maintain backwards compatibility for serialized `SpanGroups` (serialized as: a list of `SpanGroup`s, or b'') * Add tests for `SpanGroups` deserialization backwards-compatibility * Move a `SpanGroups` de/serialization test (test_issue10685) to tests/serialize/test_serialize_spangroups.py * Output a warning if deserializing a `SpanGroups` with duplicate .name-d `SpanGroup`s * Minor refactor * `SpanGroups.from_bytes` handles only `list` and `dict` types with `dict` as the expected default * For lists, keep first rather than last value encountered * Update error message * Rename and update tests * Update to preserve list serialization of SpanGroups To avoid breaking compatibility of serialized `Doc` and `DocBin` with earlier versions of spacy v3, revert back to a list-only serialization, but update the names just for serialization so that the SpanGroups keys override the SpanGroup names. * Preserve object identity and current key overwrite * Preserve SpanGroup object identity * Preserve last rather than first span group from SpanGroup list format without SpanGroups keys * Update inline comments * Fix types * Add type info for SpanGroup.copy * Deserialize `SpanGroup`s as copies when a single SpanGroup is the value for more than 1 `SpanGroups` key. This is because we serialize `SpanGroups` as dicts (to maintain backward- and forward-compatibility) and we can't assume `SpanGroup`s with the same bytes/serialization were the same (identical) object, pre-serialization. * Update spacy/tokens/_dict_proxies.py * Add more SpanGroups serialization tests Test that serialized SpanGroups maintain their Span order * small clarification on older spaCy version * Update spacy/tests/serialize/test_serialize_span_groups.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									7fe0594898
								
							
						
					
					
						commit
						7c0d582224
					
				|  | @ -200,6 +200,11 @@ class Warnings(metaclass=ErrorsWithCodes): | |||
|             "for the corpora used to train the language. Please check " | ||||
|             "`nlp.meta[\"sources\"]` for any relevant links.") | ||||
|     W119 = ("Overriding pipe name in `config` is not supported. Ignoring override '{name_in_config}'.") | ||||
|     W120 = ("Unable to load all spans in Doc.spans: more than one span group " | ||||
|             "with the name '{group_name}' was found in the saved spans data. " | ||||
|             "Only the last span group will be loaded under " | ||||
|             "Doc.spans['{group_name}']. Skipping span group with values: " | ||||
|             "{group_values}") | ||||
| 
 | ||||
| 
 | ||||
| class Errors(metaclass=ErrorsWithCodes): | ||||
|  |  | |||
							
								
								
									
										161
									
								
								spacy/tests/serialize/test_serialize_span_groups.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								spacy/tests/serialize/test_serialize_span_groups.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,161 @@ | |||
| import pytest | ||||
| 
 | ||||
| from spacy.tokens import Span, SpanGroup | ||||
| from spacy.tokens._dict_proxies import SpanGroups | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.issue(10685) | ||||
| def test_issue10685(en_tokenizer): | ||||
|     """Test `SpanGroups` de/serialization""" | ||||
|     # Start with a Doc with no SpanGroups | ||||
|     doc = en_tokenizer("Will it blend?") | ||||
| 
 | ||||
|     # Test empty `SpanGroups` de/serialization: | ||||
|     assert len(doc.spans) == 0 | ||||
|     doc.spans.from_bytes(doc.spans.to_bytes()) | ||||
|     assert len(doc.spans) == 0 | ||||
| 
 | ||||
|     # Test non-empty `SpanGroups` de/serialization: | ||||
|     doc.spans["test"] = SpanGroup(doc, name="test", spans=[doc[0:1]]) | ||||
|     doc.spans["test2"] = SpanGroup(doc, name="test", spans=[doc[1:2]]) | ||||
| 
 | ||||
|     def assert_spangroups(): | ||||
|         assert len(doc.spans) == 2 | ||||
|         assert doc.spans["test"].name == "test" | ||||
|         assert doc.spans["test2"].name == "test" | ||||
|         assert list(doc.spans["test"]) == [doc[0:1]] | ||||
|         assert list(doc.spans["test2"]) == [doc[1:2]] | ||||
| 
 | ||||
|     # Sanity check the currently-expected behavior | ||||
|     assert_spangroups() | ||||
| 
 | ||||
|     # Now test serialization/deserialization: | ||||
|     doc.spans.from_bytes(doc.spans.to_bytes()) | ||||
| 
 | ||||
|     assert_spangroups() | ||||
| 
 | ||||
| 
 | ||||
| def test_span_groups_serialization_mismatches(en_tokenizer): | ||||
|     """Test the serialization of multiple mismatching `SpanGroups` keys and `SpanGroup.name`s""" | ||||
|     doc = en_tokenizer("How now, brown cow?") | ||||
|     # Some variety: | ||||
|     # 1 SpanGroup where its name matches its key | ||||
|     # 2 SpanGroups that have the same name--which is not a key | ||||
|     # 2 SpanGroups that have the same name--which is a key | ||||
|     # 1 SpanGroup that is a value for 2 different keys (where its name is a key) | ||||
|     # 1 SpanGroup that is a value for 2 different keys (where its name is not a key) | ||||
|     groups = doc.spans | ||||
|     groups["key1"] = SpanGroup(doc, name="key1", spans=[doc[0:1], doc[1:2]]) | ||||
|     groups["key2"] = SpanGroup(doc, name="too", spans=[doc[3:4], doc[4:5]]) | ||||
|     groups["key3"] = SpanGroup(doc, name="too", spans=[doc[1:2], doc[0:1]]) | ||||
|     groups["key4"] = SpanGroup(doc, name="key4", spans=[doc[0:1]]) | ||||
|     groups["key5"] = SpanGroup(doc, name="key4", spans=[doc[0:1]]) | ||||
|     sg6 = SpanGroup(doc, name="key6", spans=[doc[0:1]]) | ||||
|     groups["key6"] = sg6 | ||||
|     groups["key7"] = sg6 | ||||
|     sg8 = SpanGroup(doc, name="also", spans=[doc[1:2]]) | ||||
|     groups["key8"] = sg8 | ||||
|     groups["key9"] = sg8 | ||||
| 
 | ||||
|     regroups = SpanGroups(doc).from_bytes(groups.to_bytes()) | ||||
| 
 | ||||
|     # Assert regroups == groups | ||||
|     assert regroups.keys() == groups.keys() | ||||
|     for key, regroup in regroups.items(): | ||||
|         # Assert regroup == groups[key] | ||||
|         assert regroup.name == groups[key].name | ||||
|         assert list(regroup) == list(groups[key]) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "spans_bytes,doc_text,expected_spangroups,expected_warning", | ||||
|     # The bytestrings below were generated from an earlier version of spaCy | ||||
|     # that serialized `SpanGroups` as a list of SpanGroup bytes (via SpanGroups.to_bytes). | ||||
|     # Comments preceding the bytestrings indicate from what Doc they were created. | ||||
|     [ | ||||
|         # Empty SpanGroups: | ||||
|         (b"\x90", "", {}, False), | ||||
|         # doc = nlp("Will it blend?") | ||||
|         # doc.spans['test'] = SpanGroup(doc, name='test', spans=[doc[0:1]]) | ||||
|         ( | ||||
|             b"\x91\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04", | ||||
|             "Will it blend?", | ||||
|             {"test": {"name": "test", "spans": [(0, 1)]}}, | ||||
|             False, | ||||
|         ), | ||||
|         # doc = nlp("Will it blend?") | ||||
|         # doc.spans['test']  = SpanGroup(doc, name='test', spans=[doc[0:1]]) | ||||
|         # doc.spans['test2'] = SpanGroup(doc, name='test', spans=[doc[1:2]]) | ||||
|         ( | ||||
|             b"\x92\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x04\xc4C\x83\xa4name\xa4test\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x07", | ||||
|             "Will it blend?", | ||||
|             # We expect only 1 SpanGroup to be in doc.spans in this example | ||||
|             # because there are 2 `SpanGroup`s that have the same .name. See #10685. | ||||
|             {"test": {"name": "test", "spans": [(1, 2)]}}, | ||||
|             True, | ||||
|         ), | ||||
|         # doc = nlp('How now, brown cow?') | ||||
|         # doc.spans['key1'] = SpanGroup(doc, name='key1', spans=[doc[0:1], doc[1:2]]) | ||||
|         # doc.spans['key2'] = SpanGroup(doc, name='too', spans=[doc[3:4], doc[4:5]]) | ||||
|         # doc.spans['key3'] = SpanGroup(doc, name='too', spans=[doc[1:2], doc[0:1]]) | ||||
|         # doc.spans['key4'] = SpanGroup(doc, name='key4', spans=[doc[0:1]]) | ||||
|         # doc.spans['key5'] = SpanGroup(doc, name='key4', spans=[doc[0:1]]) | ||||
|         ( | ||||
|             b"\x95\xc4m\x83\xa4name\xa4key1\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\t\x00\x00\x00\x0e\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x05\x00\x00\x00\x0f\x00\x00\x00\x12\xc4l\x83\xa4name\xa3too\xa5attrs\x80\xa5spans\x92\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x07\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\xc4C\x83\xa4name\xa4key4\xa5attrs\x80\xa5spans\x91\xc4(\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03", | ||||
|             "How now, brown cow?", | ||||
|             { | ||||
|                 "key1": {"name": "key1", "spans": [(0, 1), (1, 2)]}, | ||||
|                 "too": {"name": "too", "spans": [(1, 2), (0, 1)]}, | ||||
|                 "key4": {"name": "key4", "spans": [(0, 1)]}, | ||||
|             }, | ||||
|             True, | ||||
|         ), | ||||
|     ], | ||||
| ) | ||||
| def test_deserialize_span_groups_compat( | ||||
|     en_tokenizer, spans_bytes, doc_text, expected_spangroups, expected_warning | ||||
| ): | ||||
|     """Test backwards-compatibility of `SpanGroups` deserialization. | ||||
|     This uses serializations (bytes) from a prior version of spaCy (before 3.3.1). | ||||
| 
 | ||||
|     spans_bytes (bytes): Serialized `SpanGroups` object. | ||||
|     doc_text (str): Doc text. | ||||
|     expected_spangroups (dict): | ||||
|         Dict mapping every expected (after deserialization) `SpanGroups` key | ||||
|         to a SpanGroup's "args", where a SpanGroup's args are given as a dict: | ||||
|           {"name": span_group.name, | ||||
|            "spans": [(span0.start, span0.end), ...]} | ||||
|     expected_warning (bool): Whether a warning is to be expected from .from_bytes() | ||||
|         --i.e. if more than 1 SpanGroup has the same .name within the `SpanGroups`. | ||||
|     """ | ||||
|     doc = en_tokenizer(doc_text) | ||||
| 
 | ||||
|     if expected_warning: | ||||
|         with pytest.warns(UserWarning): | ||||
|             doc.spans.from_bytes(spans_bytes) | ||||
|     else: | ||||
|         # TODO: explicitly check for lack of a warning | ||||
|         doc.spans.from_bytes(spans_bytes) | ||||
| 
 | ||||
|     assert doc.spans.keys() == expected_spangroups.keys() | ||||
|     for name, spangroup_args in expected_spangroups.items(): | ||||
|         assert doc.spans[name].name == spangroup_args["name"] | ||||
|         spans = [Span(doc, start, end) for start, end in spangroup_args["spans"]] | ||||
|         assert list(doc.spans[name]) == spans | ||||
| 
 | ||||
| 
 | ||||
| def test_span_groups_serialization(en_tokenizer): | ||||
|     doc = en_tokenizer("0 1 2 3 4 5 6") | ||||
|     span_groups = SpanGroups(doc) | ||||
|     spans = [doc[0:2], doc[1:3]] | ||||
|     sg1 = SpanGroup(doc, spans=spans) | ||||
|     span_groups["key1"] = sg1 | ||||
|     span_groups["key2"] = sg1 | ||||
|     span_groups["key3"] = [] | ||||
|     reloaded_span_groups = SpanGroups(doc).from_bytes(span_groups.to_bytes()) | ||||
|     assert span_groups.keys() == reloaded_span_groups.keys() | ||||
|     for key, value in span_groups.items(): | ||||
|         assert all( | ||||
|             span == reloaded_span | ||||
|             for span, reloaded_span in zip(span_groups[key], reloaded_span_groups[key]) | ||||
|         ) | ||||
|  | @ -1,10 +1,11 @@ | |||
| from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING | ||||
| from typing import Dict, Iterable, List, Tuple, Union, Optional, TYPE_CHECKING | ||||
| import warnings | ||||
| import weakref | ||||
| from collections import UserDict | ||||
| import srsly | ||||
| 
 | ||||
| from .span_group import SpanGroup | ||||
| from ..errors import Errors | ||||
| from ..errors import Errors, Warnings | ||||
| 
 | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|  | @ -16,7 +17,7 @@ if TYPE_CHECKING: | |||
| # Why inherit from UserDict instead of dict here? | ||||
| # Well, the 'dict' class doesn't necessarily delegate everything nicely, | ||||
| # for performance reasons. The UserDict is slower but better behaved. | ||||
| # See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww | ||||
| # See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/ | ||||
| class SpanGroups(UserDict): | ||||
|     """A dict-like proxy held by the Doc, to control access to span groups.""" | ||||
| 
 | ||||
|  | @ -53,20 +54,52 @@ class SpanGroups(UserDict): | |||
|         return super().setdefault(key, default=default) | ||||
| 
 | ||||
|     def to_bytes(self) -> bytes: | ||||
|         # We don't need to serialize this as a dict, because the groups | ||||
|         # know their names. | ||||
|         # We serialize this as a dict in order to track the key(s) a SpanGroup | ||||
|         # is a value of (in a backward- and forward-compatible way), since | ||||
|         # a SpanGroup can have a key that doesn't match its `.name` (See #10685) | ||||
|         if len(self) == 0: | ||||
|             return self._EMPTY_BYTES | ||||
|         msg = [value.to_bytes() for value in self.values()] | ||||
|         msg: Dict[bytes, List[str]] = {} | ||||
|         for key, value in self.items(): | ||||
|             msg.setdefault(value.to_bytes(), []).append(key) | ||||
|         return srsly.msgpack_dumps(msg) | ||||
| 
 | ||||
|     def from_bytes(self, bytes_data: bytes) -> "SpanGroups": | ||||
|         msg = [] if bytes_data == self._EMPTY_BYTES else srsly.msgpack_loads(bytes_data) | ||||
|         # backwards-compatibility: bytes_data may be one of: | ||||
|         # b'', a serialized empty list, a serialized list of SpanGroup bytes | ||||
|         # or a serialized dict of SpanGroup bytes -> keys | ||||
|         msg = ( | ||||
|             [] | ||||
|             if not bytes_data or bytes_data == self._EMPTY_BYTES | ||||
|             else srsly.msgpack_loads(bytes_data) | ||||
|         ) | ||||
|         self.clear() | ||||
|         doc = self._ensure_doc() | ||||
|         for value_bytes in msg: | ||||
|             group = SpanGroup(doc).from_bytes(value_bytes) | ||||
|             self[group.name] = group | ||||
|         if isinstance(msg, list): | ||||
|             # This is either the 1st version of `SpanGroups` serialization | ||||
|             # or there were no SpanGroups serialized | ||||
|             for value_bytes in msg: | ||||
|                 group = SpanGroup(doc).from_bytes(value_bytes) | ||||
|                 if group.name in self: | ||||
|                     # Display a warning if `msg` contains `SpanGroup`s | ||||
|                     # that have the same .name (attribute). | ||||
|                     # Because, for `SpanGroups` serialized as lists, | ||||
|                     # only 1 SpanGroup per .name is loaded. (See #10685) | ||||
|                     warnings.warn( | ||||
|                         Warnings.W120.format( | ||||
|                             group_name=group.name, group_values=self[group.name] | ||||
|                         ) | ||||
|                     ) | ||||
|                 self[group.name] = group | ||||
|         else: | ||||
|             for value_bytes, keys in msg.items(): | ||||
|                 group = SpanGroup(doc).from_bytes(value_bytes) | ||||
|                 # Deserialize `SpanGroup`s as copies because it's possible for two | ||||
|                 # different `SpanGroup`s (pre-serialization) to have the same bytes | ||||
|                 # (since they can have the same `.name`). | ||||
|                 self[keys[0]] = group | ||||
|                 for key in keys[1:]: | ||||
|                     self[key] = group.copy() | ||||
|         return self | ||||
| 
 | ||||
|     def _ensure_doc(self) -> "Doc": | ||||
|  |  | |||
|  | @ -24,3 +24,4 @@ class SpanGroup: | |||
|     def __getitem__(self, i: int) -> Span: ... | ||||
|     def to_bytes(self) -> bytes: ... | ||||
|     def from_bytes(self, bytes_data: bytes) -> SpanGroup: ... | ||||
|     def copy(self) -> SpanGroup: ... | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user