Include Doc.cats in serialization of Doc and DocBin (#4774)

* Include Doc.cats in to_bytes()

* Include Doc.cats in DocBin serialization

* Add tests for serialization of cats

Test serialization of cats for Doc and DocBin.
This commit is contained in:
adrianeboyd 2019-12-06 14:07:39 +01:00 committed by Matthew Honnibal
parent e626a011cc
commit 676e75838f
3 changed files with 17 additions and 1 deletions

View File

@ -24,6 +24,7 @@ def test_serialize_empty_doc(en_vocab):
def test_serialize_doc_roundtrip_bytes(en_vocab): def test_serialize_doc_roundtrip_bytes(en_vocab):
doc = Doc(en_vocab, words=["hello", "world"]) doc = Doc(en_vocab, words=["hello", "world"])
doc.cats = {"A": 0.5}
doc_b = doc.to_bytes() doc_b = doc.to_bytes()
new_doc = Doc(en_vocab).from_bytes(doc_b) new_doc = Doc(en_vocab).from_bytes(doc_b)
assert new_doc.to_bytes() == doc_b assert new_doc.to_bytes() == doc_b
@ -66,12 +67,17 @@ def test_serialize_doc_exclude(en_vocab):
def test_serialize_doc_bin(): def test_serialize_doc_bin():
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True) doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
texts = ["Some text", "Lots of texts...", "..."] texts = ["Some text", "Lots of texts...", "..."]
cats = {"A": 0.5}
nlp = English() nlp = English()
for doc in nlp.pipe(texts): for doc in nlp.pipe(texts):
doc.cats = cats
doc_bin.add(doc) doc_bin.add(doc)
bytes_data = doc_bin.to_bytes() bytes_data = doc_bin.to_bytes()
# Deserialize later, e.g. in a new process # Deserialize later, e.g. in a new process
nlp = spacy.blank("en") nlp = spacy.blank("en")
doc_bin = DocBin().from_bytes(bytes_data) doc_bin = DocBin().from_bytes(bytes_data)
list(doc_bin.get_docs(nlp.vocab)) reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
for i, doc in enumerate(reloaded_docs):
assert doc.text == texts[i]
assert doc.cats == cats

View File

@ -58,6 +58,7 @@ class DocBin(object):
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0] self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
self.tokens = [] self.tokens = []
self.spaces = [] self.spaces = []
self.cats = []
self.user_data = [] self.user_data = []
self.strings = set() self.strings = set()
self.store_user_data = store_user_data self.store_user_data = store_user_data
@ -82,6 +83,7 @@ class DocBin(object):
spaces = spaces.reshape((spaces.shape[0], 1)) spaces = spaces.reshape((spaces.shape[0], 1))
self.spaces.append(numpy.asarray(spaces, dtype=bool)) self.spaces.append(numpy.asarray(spaces, dtype=bool))
self.strings.update(w.text for w in doc) self.strings.update(w.text for w in doc)
self.cats.append(doc.cats)
if self.store_user_data: if self.store_user_data:
self.user_data.append(srsly.msgpack_dumps(doc.user_data)) self.user_data.append(srsly.msgpack_dumps(doc.user_data))
@ -102,6 +104,7 @@ class DocBin(object):
words = [vocab.strings[orth] for orth in tokens[:, orth_col]] words = [vocab.strings[orth] for orth in tokens[:, orth_col]]
doc = Doc(vocab, words=words, spaces=spaces) doc = Doc(vocab, words=words, spaces=spaces)
doc = doc.from_array(self.attrs, tokens) doc = doc.from_array(self.attrs, tokens)
doc.cats = self.cats[i]
if self.store_user_data: if self.store_user_data:
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False) user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
doc.user_data.update(user_data) doc.user_data.update(user_data)
@ -121,6 +124,7 @@ class DocBin(object):
self.tokens.extend(other.tokens) self.tokens.extend(other.tokens)
self.spaces.extend(other.spaces) self.spaces.extend(other.spaces)
self.strings.update(other.strings) self.strings.update(other.strings)
self.cats.extend(other.cats)
if self.store_user_data: if self.store_user_data:
self.user_data.extend(other.user_data) self.user_data.extend(other.user_data)
@ -140,6 +144,7 @@ class DocBin(object):
"spaces": numpy.vstack(self.spaces).tobytes("C"), "spaces": numpy.vstack(self.spaces).tobytes("C"),
"lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"), "lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
"strings": list(self.strings), "strings": list(self.strings),
"cats": self.cats,
} }
if self.store_user_data: if self.store_user_data:
msg["user_data"] = self.user_data msg["user_data"] = self.user_data
@ -164,6 +169,7 @@ class DocBin(object):
flat_spaces = flat_spaces.reshape((flat_spaces.size, 1)) flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
self.tokens = NumpyOps().unflatten(flat_tokens, lengths) self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
self.spaces = NumpyOps().unflatten(flat_spaces, lengths) self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
self.cats = msg["cats"]
if self.store_user_data and "user_data" in msg: if self.store_user_data and "user_data" in msg:
self.user_data = list(msg["user_data"]) self.user_data = list(msg["user_data"])
for tokens in self.tokens: for tokens in self.tokens:

View File

@ -887,6 +887,7 @@ cdef class Doc:
"array_body": lambda: self.to_array(array_head), "array_body": lambda: self.to_array(array_head),
"sentiment": lambda: self.sentiment, "sentiment": lambda: self.sentiment,
"tensor": lambda: self.tensor, "tensor": lambda: self.tensor,
"cats": lambda: self.cats,
} }
for key in kwargs: for key in kwargs:
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"): if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
@ -916,6 +917,7 @@ cdef class Doc:
"array_body": lambda b: None, "array_body": lambda b: None,
"sentiment": lambda b: None, "sentiment": lambda b: None,
"tensor": lambda b: None, "tensor": lambda b: None,
"cats": lambda b: None,
"user_data_keys": lambda b: None, "user_data_keys": lambda b: None,
"user_data_values": lambda b: None, "user_data_values": lambda b: None,
} }
@ -937,6 +939,8 @@ cdef class Doc:
self.sentiment = msg["sentiment"] self.sentiment = msg["sentiment"]
if "tensor" not in exclude and "tensor" in msg: if "tensor" not in exclude and "tensor" in msg:
self.tensor = msg["tensor"] self.tensor = msg["tensor"]
if "cats" not in exclude and "cats" in msg:
self.cats = msg["cats"]
start = 0 start = 0
cdef const LexemeC* lex cdef const LexemeC* lex
cdef unicode orth_ cdef unicode orth_