Support to_dict in Doc

This commit is contained in:
Matthew Honnibal 2020-06-06 15:10:10 +02:00
parent 7b873ce2b1
commit 1d2e39d974

View File

@ -881,6 +881,32 @@ cdef class Doc:
def to_bytes(self, exclude=tuple(), **kwargs):
    """Serialize the document contents to a msgpack-encoded binary string.

    exclude (list): String names of serialization fields to exclude.
    **kwargs: Passed through unchanged to `to_dict`.
    RETURNS (bytes): A losslessly serialized copy of the `Doc`, including
        all annotations.

    DOCS: https://spacy.io/api/doc#to_bytes
    """
    # Build the dictionary form first, then let msgpack encode it.
    doc_dict = self.to_dict(exclude=exclude, **kwargs)
    return srsly.msgpack_dumps(doc_dict)
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
    """Deserialize the document contents from a msgpack-encoded binary string.

    bytes_data (bytes): The msgpack data to load from.
    exclude (list): String names of serialization fields to exclude.
    **kwargs: Passed through unchanged to `from_dict`.
    RETURNS (Doc): Itself.

    DOCS: https://spacy.io/api/doc#from_bytes
    """
    # Decode the msgpack payload, then hand the result to from_dict.
    msg = srsly.msgpack_loads(bytes_data)
    return self.from_dict(msg, exclude=exclude, **kwargs)
def to_dict(self, exclude=tuple(), **kwargs):
"""Export the document contents to a dictionary for serialization.
exclude (list): String names of serialization fields to exclude. exclude (list): String names of serialization fields to exclude.
RETURNS (bytes): A losslessly serialized copy of the `Doc`, including RETURNS (dict): A losslessly serialized copy of the `Doc`, including
all annotations. all annotations.
@ -917,9 +943,9 @@ cdef class Doc:
serializers["user_data_keys"] = lambda: srsly.msgpack_dumps(user_data_keys) serializers["user_data_keys"] = lambda: srsly.msgpack_dumps(user_data_keys)
if "user_data_values" not in exclude: if "user_data_values" not in exclude:
serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values) serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values)
return util.to_bytes(serializers, exclude) return util.to_dict(serializers, exclude)
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs): def from_dict(self, msg, exclude=tuple(), **kwargs):
"""Deserialize, i.e. import the document contents from a binary string. """Deserialize, i.e. import the document contents from a dictionary.
data (bytes): The string to load from. msg (dict): The dictionary to load from.
@ -943,7 +969,6 @@ cdef class Doc:
for key in kwargs: for key in kwargs:
if key in deserializers or key in ("user_data",): if key in deserializers or key in ("user_data",):
raise ValueError(Errors.E128.format(arg=key)) raise ValueError(Errors.E128.format(arg=key))
msg = util.from_bytes(bytes_data, deserializers, exclude)
# Msgpack doesn't distinguish between lists and tuples, which is # Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within # vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope # keys, we must have tuples. In values we just have to hope
@ -975,6 +1000,7 @@ cdef class Doc:
self.from_array(msg["array_head"][2:], attrs[:, 2:]) self.from_array(msg["array_head"][2:], attrs[:, 2:])
return self return self
def extend_tensor(self, tensor): def extend_tensor(self, tensor):
"""Concatenate a new tensor onto the doc.tensor object. """Concatenate a new tensor onto the doc.tensor object.