from typing import Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING
import weakref
from collections import UserDict
import srsly

from .span_group import SpanGroup
from ..errors import Errors


if TYPE_CHECKING:
    # This lets us add type hints for mypy etc. without causing circular imports
    from .doc import Doc  # noqa: F401
    from .span import Span  # noqa: F401


# Why inherit from UserDict instead of dict here?
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
# for performance reasons. The UserDict is slower but better behaved.
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/
class SpanGroups(UserDict):
    """A dict-like proxy held by the Doc, to control access to span groups.

    Every value is a SpanGroup belonging to the referenced Doc. Plain
    iterables of spans are wrapped in a SpanGroup on assignment. The Doc is
    held through a weak reference, so this proxy does not keep it alive.
    """

    # Serialized form of an empty container (a msgpack-encoded empty list).
    _EMPTY_BYTES = srsly.msgpack_dumps([])

    def __init__(
        self, doc: "Doc", items: Iterable[Tuple[str, SpanGroup]] = tuple()
    ) -> None:
        # doc_ref must exist before UserDict.__init__ runs: populating from
        # `items` goes through __setitem__, which dereferences it.
        self.doc_ref = weakref.ref(doc)
        UserDict.__init__(self, items)  # type: ignore[arg-type]

    def __setitem__(self, key: str, value: Union[SpanGroup, Iterable["Span"]]) -> None:
        """Store `value` under `key`, wrapping bare span iterables in a SpanGroup.

        A stored SpanGroup may have a `.name` that differs from `key`; the
        key (not the name) is what to_bytes/from_bytes preserve.
        """
        if not isinstance(value, SpanGroup):
            value = self._make_span_group(key, value)
        # Groups must belong to this container's Doc. NOTE: this is an
        # `assert`, so the check is skipped when Python runs with -O.
        assert value.doc is self.doc_ref()
        UserDict.__setitem__(self, key, value)

    def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
        """Build a SpanGroup named `name` over `spans` for this container's Doc.

        Raises ValueError (Errors.E866) if the Doc has been garbage-collected.
        """
        doc = self._ensure_doc()
        return SpanGroup(doc, name=name, spans=spans)

    def copy(self, doc: Optional["Doc"] = None) -> "SpanGroups":
        """Return a new SpanGroups whose groups are deserialized copies of ours.

        doc: The Doc the copy belongs to. Defaults to this container's Doc.
        Keys are preserved, including keys that differ from a group's name.
        """
        if doc is None:
            doc = self._ensure_doc()
        return SpanGroups(doc).from_bytes(self.to_bytes())

    def setdefault(self, key, default=None):
        """Return the group at `key`, first storing `default` if `key` is absent.

        A non-SpanGroup `default` (None meaning "no spans") is wrapped in a
        SpanGroup named `key` before delegating to the standard setdefault.
        """
        if not isinstance(default, SpanGroup):
            if default is None:
                spans = []
            else:
                spans = default
            default = self._make_span_group(key, spans)
        return super().setdefault(key, default=default)

    def to_bytes(self) -> bytes:
        """Serialize all groups together with the key(s) they are stored under.

        The payload maps each serialized group to the list of keys holding
        it. A group's `.name` need not match its key, and several keys may
        hold groups with the same name. Serializing only the groups (the
        legacy list format) would lose or merge those keys. Identical
        serialized groups are written once, with all their keys.
        """
        if len(self) == 0:
            return self._EMPTY_BYTES
        msg: Dict[bytes, List[str]] = {}
        for key, value in self.items():
            msg.setdefault(value.to_bytes(), []).append(key)
        return srsly.msgpack_dumps(msg)

    def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
        """Replace the contents with groups deserialized from `bytes_data`.

        Accepted inputs:
        - the current format, a mapping of serialized group -> list of keys;
        - the legacy format, a list of serialized groups, each restored under
          its own `.name`;
        - a serialized empty list, or empty bytes.

        Returns self, to allow chaining.
        Raises ValueError (Errors.E866) if the Doc has been garbage-collected.
        """
        msg: Union[List[bytes], Dict[bytes, List[str]]]
        if not bytes_data or bytes_data == self._EMPTY_BYTES:
            msg = []
        else:
            msg = srsly.msgpack_loads(bytes_data)
        self.clear()
        doc = self._ensure_doc()
        if isinstance(msg, list):
            # Legacy format: only the group names are available as keys.
            for value_bytes in msg:
                group = SpanGroup(doc).from_bytes(value_bytes)
                self[group.name] = group
        else:
            for value_bytes, keys in msg.items():
                # Deserialize once per key so every key gets its own
                # independent SpanGroup object.
                for key in keys:
                    self[key] = SpanGroup(doc).from_bytes(value_bytes)
        return self

    def _ensure_doc(self) -> "Doc":
        """Return the referenced Doc.

        Raises ValueError (Errors.E866) if it has been garbage-collected.
        """
        doc = self.doc_ref()
        if doc is None:
            raise ValueError(Errors.E866)
        return doc