mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Include Doc.cats in serialization of Doc and DocBin (#4774)
* Include Doc.cats in to_bytes() * Include Doc.cats in DocBin serialization * Add tests for serialization of cats Test serialization of cats for Doc and DocBin.
This commit is contained in:
parent
e626a011cc
commit
676e75838f
|
@ -24,6 +24,7 @@ def test_serialize_empty_doc(en_vocab):
|
|||
|
||||
def test_serialize_doc_roundtrip_bytes(en_vocab):
|
||||
doc = Doc(en_vocab, words=["hello", "world"])
|
||||
doc.cats = {"A": 0.5}
|
||||
doc_b = doc.to_bytes()
|
||||
new_doc = Doc(en_vocab).from_bytes(doc_b)
|
||||
assert new_doc.to_bytes() == doc_b
|
||||
|
@ -66,12 +67,17 @@ def test_serialize_doc_exclude(en_vocab):
|
|||
def test_serialize_doc_bin():
|
||||
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
|
||||
texts = ["Some text", "Lots of texts...", "..."]
|
||||
cats = {"A": 0.5}
|
||||
nlp = English()
|
||||
for doc in nlp.pipe(texts):
|
||||
doc.cats = cats
|
||||
doc_bin.add(doc)
|
||||
bytes_data = doc_bin.to_bytes()
|
||||
|
||||
# Deserialize later, e.g. in a new process
|
||||
nlp = spacy.blank("en")
|
||||
doc_bin = DocBin().from_bytes(bytes_data)
|
||||
list(doc_bin.get_docs(nlp.vocab))
|
||||
reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
|
||||
for i, doc in enumerate(reloaded_docs):
|
||||
assert doc.text == texts[i]
|
||||
assert doc.cats == cats
|
||||
|
|
|
@ -58,6 +58,7 @@ class DocBin(object):
|
|||
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
|
||||
self.tokens = []
|
||||
self.spaces = []
|
||||
self.cats = []
|
||||
self.user_data = []
|
||||
self.strings = set()
|
||||
self.store_user_data = store_user_data
|
||||
|
@ -82,6 +83,7 @@ class DocBin(object):
|
|||
spaces = spaces.reshape((spaces.shape[0], 1))
|
||||
self.spaces.append(numpy.asarray(spaces, dtype=bool))
|
||||
self.strings.update(w.text for w in doc)
|
||||
self.cats.append(doc.cats)
|
||||
if self.store_user_data:
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
|
||||
|
@ -102,6 +104,7 @@ class DocBin(object):
|
|||
words = [vocab.strings[orth] for orth in tokens[:, orth_col]]
|
||||
doc = Doc(vocab, words=words, spaces=spaces)
|
||||
doc = doc.from_array(self.attrs, tokens)
|
||||
doc.cats = self.cats[i]
|
||||
if self.store_user_data:
|
||||
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
||||
doc.user_data.update(user_data)
|
||||
|
@ -121,6 +124,7 @@ class DocBin(object):
|
|||
self.tokens.extend(other.tokens)
|
||||
self.spaces.extend(other.spaces)
|
||||
self.strings.update(other.strings)
|
||||
self.cats.extend(other.cats)
|
||||
if self.store_user_data:
|
||||
self.user_data.extend(other.user_data)
|
||||
|
||||
|
@ -140,6 +144,7 @@ class DocBin(object):
|
|||
"spaces": numpy.vstack(self.spaces).tobytes("C"),
|
||||
"lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
|
||||
"strings": list(self.strings),
|
||||
"cats": self.cats,
|
||||
}
|
||||
if self.store_user_data:
|
||||
msg["user_data"] = self.user_data
|
||||
|
@ -164,6 +169,7 @@ class DocBin(object):
|
|||
flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
|
||||
self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
|
||||
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
||||
self.cats = msg["cats"]
|
||||
if self.store_user_data and "user_data" in msg:
|
||||
self.user_data = list(msg["user_data"])
|
||||
for tokens in self.tokens:
|
||||
|
|
|
@ -887,6 +887,7 @@ cdef class Doc:
|
|||
"array_body": lambda: self.to_array(array_head),
|
||||
"sentiment": lambda: self.sentiment,
|
||||
"tensor": lambda: self.tensor,
|
||||
"cats": lambda: self.cats,
|
||||
}
|
||||
for key in kwargs:
|
||||
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
|
||||
|
@ -916,6 +917,7 @@ cdef class Doc:
|
|||
"array_body": lambda b: None,
|
||||
"sentiment": lambda b: None,
|
||||
"tensor": lambda b: None,
|
||||
"cats": lambda b: None,
|
||||
"user_data_keys": lambda b: None,
|
||||
"user_data_values": lambda b: None,
|
||||
}
|
||||
|
@ -937,6 +939,8 @@ cdef class Doc:
|
|||
self.sentiment = msg["sentiment"]
|
||||
if "tensor" not in exclude and "tensor" in msg:
|
||||
self.tensor = msg["tensor"]
|
||||
if "cats" not in exclude and "cats" in msg:
|
||||
self.cats = msg["cats"]
|
||||
start = 0
|
||||
cdef const LexemeC* lex
|
||||
cdef unicode orth_
|
||||
|
|
Loading…
Reference in New Issue
Block a user