DocPallet -> DocBin

This commit is contained in:
Matthew Honnibal 2019-09-18 15:15:37 +02:00
parent f537cbeacc
commit e53b86751f
2 changed files with 23 additions and 23 deletions

View File

@ -4,6 +4,6 @@ from __future__ import unicode_literals
from .doc import Doc from .doc import Doc
from .token import Token from .token import Token
from .span import Span from .span import Span
from ._serialize import DocPallet from ._serialize import DocBin
__all__ = ["Doc", "Token", "Span", "DocPallet"] __all__ = ["Doc", "Token", "Span", "DocBin"]

View File

@ -11,13 +11,13 @@ from ..tokens import Doc
from ..attrs import SPACY, ORTH from ..attrs import SPACY, ORTH
class DocPallet(object): class DocBin(object):
"""Pack Doc objects for export. """Pack Doc objects for binary serialization.
The DocPallet class lets you efficiently serialize the information from a The DocBin class lets you efficiently serialize the information from a
collection of Doc objects. You can control which information is serialized collection of Doc objects. You can control which information is serialized
by passing a list of attribute IDs, and optionally also specify whether the by passing a list of attribute IDs, and optionally also specify whether the
user data is serialized. The DocPallet is faster and produces smaller data user data is serialized. The DocBin is faster and produces smaller data
sizes than pickle, and allows you to deserialize without executing arbitrary sizes than pickle, and allows you to deserialize without executing arbitrary
Python code. Python code.
@ -41,7 +41,7 @@ class DocPallet(object):
document from the pallet. document from the pallet.
""" """
def __init__(self, attrs=None, store_user_data=False): def __init__(self, attrs=None, store_user_data=False):
"""Create a DocPallet object, to hold serialized annotations. """Create a DocBin object, to hold serialized annotations.
attrs (list): List of attributes to serialize. 'orth' and 'spacy' are attrs (list): List of attributes to serialize. 'orth' and 'spacy' are
always serialized, so they're not required. Defaults to None. always serialized, so they're not required. Defaults to None.
@ -57,7 +57,7 @@ class DocPallet(object):
self.store_user_data = store_user_data self.store_user_data = store_user_data
def add(self, doc): def add(self, doc):
"""Add a doc's annotations to the DocPallet for serialization.""" """Add a doc's annotations to the DocBin for serialization."""
array = doc.to_array(self.attrs) array = doc.to_array(self.attrs)
if len(array.shape) == 1: if len(array.shape) == 1:
array = array.reshape((array.shape[0], 1)) array = array.reshape((array.shape[0], 1))
@ -86,7 +86,7 @@ class DocPallet(object):
yield doc yield doc
def merge(self, other): def merge(self, other):
"""Extend the annotations of this DocPallet with the annotations from another.""" """Extend the annotations of this DocBin with the annotations from another."""
assert self.attrs == other.attrs assert self.attrs == other.attrs
self.tokens.extend(other.tokens) self.tokens.extend(other.tokens)
self.spaces.extend(other.spaces) self.spaces.extend(other.spaces)
@ -95,7 +95,7 @@ class DocPallet(object):
self.user_data.extend(other.user_data) self.user_data.extend(other.user_data)
def to_bytes(self): def to_bytes(self):
"""Serialize the DocPallet's annotations into a byte string.""" """Serialize the DocBin's annotations into a byte string."""
for tokens in self.tokens: for tokens in self.tokens:
assert len(tokens.shape) == 2, tokens.shape assert len(tokens.shape) == 2, tokens.shape
lengths = [len(tokens) for tokens in self.tokens] lengths = [len(tokens) for tokens in self.tokens]
@ -111,7 +111,7 @@ class DocPallet(object):
return gzip.compress(srsly.msgpack_dumps(msg)) return gzip.compress(srsly.msgpack_dumps(msg))
def from_bytes(self, string): def from_bytes(self, string):
"""Deserialize the DocPallet's annotations from a byte string.""" """Deserialize the DocBin's annotations from a byte string."""
msg = srsly.msgpack_loads(gzip.decompress(string)) msg = srsly.msgpack_loads(gzip.decompress(string))
self.attrs = msg["attrs"] self.attrs = msg["attrs"]
self.strings = set(msg["strings"]) self.strings = set(msg["strings"])
@ -130,31 +130,31 @@ class DocPallet(object):
return self return self
def merge_boxes(boxes): def merge_bins(bins):
merged = None merged = None
for byte_string in boxes: for byte_string in bins:
if byte_string is not None: if byte_string is not None:
box = DocPallet(store_user_data=True).from_bytes(byte_string) doc_bin = DocBin(store_user_data=True).from_bytes(byte_string)
if merged is None: if merged is None:
merged = box merged = doc_bin
else: else:
merged.merge(box) merged.merge(doc_bin)
if merged is not None: if merged is not None:
return merged.to_bytes() return merged.to_bytes()
else: else:
return b"" return b""
def pickle_box(box): def pickle_bin(docbin):
return (unpickle_box, (box.to_bytes(),)) return (unpickle_bin, (bin_.to_bytes(),))
def unpickle_box(byte_string): def unpickle_bin(byte_string):
return DocPallet().from_bytes(byte_string) return DocBin().from_bytes(byte_string)
copy_reg.pickle(DocPallet, pickle_box, unpickle_box) copy_reg.pickle(DocBin, pickle_bin, unpickle_bin)
# Compatibility, as we had named it this previously. # Compatibility, as we had named it this previously.
Binder = DocPallet Binder = DocBin
__all__ = ["DocPallet"] __all__ = ["DocBin"]