diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py
index 87b087760..ef2b1ee89 100644
--- a/spacy/tests/serialize/test_serialize_doc.py
+++ b/spacy/tests/serialize/test_serialize_doc.py
@@ -24,6 +24,7 @@ def test_serialize_empty_doc(en_vocab):
 
 def test_serialize_doc_roundtrip_bytes(en_vocab):
     doc = Doc(en_vocab, words=["hello", "world"])
+    doc.cats = {"A": 0.5}
     doc_b = doc.to_bytes()
     new_doc = Doc(en_vocab).from_bytes(doc_b)
     assert new_doc.to_bytes() == doc_b
@@ -66,12 +67,17 @@ def test_serialize_doc_exclude(en_vocab):
 
 def test_serialize_doc_bin():
     doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
     texts = ["Some text", "Lots of texts...", "..."]
+    cats = {"A": 0.5}
     nlp = English()
     for doc in nlp.pipe(texts):
+        doc.cats = cats
         doc_bin.add(doc)
     bytes_data = doc_bin.to_bytes()
     # Deserialize later, e.g. in a new process
     nlp = spacy.blank("en")
     doc_bin = DocBin().from_bytes(bytes_data)
-    list(doc_bin.get_docs(nlp.vocab))
+    reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
+    for i, doc in enumerate(reloaded_docs):
+        assert doc.text == texts[i]
+        assert doc.cats == cats
diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py
index 18cb8a234..b60a6d7b3 100644
--- a/spacy/tokens/_serialize.py
+++ b/spacy/tokens/_serialize.py
@@ -58,6 +58,7 @@ class DocBin(object):
         self.attrs.insert(0, ORTH)  # Ensure ORTH is always attrs[0]
         self.tokens = []
         self.spaces = []
+        self.cats = []
         self.user_data = []
         self.strings = set()
         self.store_user_data = store_user_data
@@ -82,6 +83,7 @@ class DocBin(object):
         spaces = spaces.reshape((spaces.shape[0], 1))
         self.spaces.append(numpy.asarray(spaces, dtype=bool))
         self.strings.update(w.text for w in doc)
+        self.cats.append(doc.cats)
         if self.store_user_data:
             self.user_data.append(srsly.msgpack_dumps(doc.user_data))
 
@@ -102,6 +104,7 @@ class DocBin(object):
             words = [vocab.strings[orth] for orth in tokens[:, orth_col]]
             doc = Doc(vocab, words=words, spaces=spaces)
             doc = doc.from_array(self.attrs, tokens)
+            doc.cats = self.cats[i]
             if self.store_user_data:
                 user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
                 doc.user_data.update(user_data)
@@ -121,6 +124,7 @@ class DocBin(object):
         self.tokens.extend(other.tokens)
         self.spaces.extend(other.spaces)
         self.strings.update(other.strings)
+        self.cats.extend(other.cats)
         if self.store_user_data:
             self.user_data.extend(other.user_data)
 
@@ -140,6 +144,7 @@ class DocBin(object):
             "spaces": numpy.vstack(self.spaces).tobytes("C"),
             "lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
             "strings": list(self.strings),
+            "cats": self.cats,
         }
         if self.store_user_data:
             msg["user_data"] = self.user_data
@@ -164,6 +169,7 @@ class DocBin(object):
             flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
         self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
         self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
+        self.cats = msg["cats"]
         if self.store_user_data and "user_data" in msg:
             self.user_data = list(msg["user_data"])
         for tokens in self.tokens:
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index d4ba3803b..716df1087 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -887,6 +887,7 @@ cdef class Doc:
             "array_body": lambda: self.to_array(array_head),
             "sentiment": lambda: self.sentiment,
             "tensor": lambda: self.tensor,
+            "cats": lambda: self.cats,
         }
         for key in kwargs:
             if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
@@ -916,6 +917,7 @@ cdef class Doc:
             "array_body": lambda b: None,
             "sentiment": lambda b: None,
             "tensor": lambda b: None,
+            "cats": lambda b: None,
             "user_data_keys": lambda b: None,
             "user_data_values": lambda b: None,
         }
@@ -937,6 +939,8 @@ cdef class Doc:
             self.sentiment = msg["sentiment"]
         if "tensor" not in exclude and "tensor" in msg:
             self.tensor = msg["tensor"]
+        if "cats" not in exclude and "cats" in msg:
+            self.cats = msg["cats"]
         start = 0
         cdef const LexemeC* lex
         cdef unicode orth_