mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Include Doc.cats in serialization of Doc and DocBin (#4774)
* Include Doc.cats in to_bytes()
* Include Doc.cats in DocBin serialization
* Add tests for serialization of cats

Test serialization of cats for Doc and DocBin.
This commit is contained in:
parent
e626a011cc
commit
676e75838f
|
@ -24,6 +24,7 @@ def test_serialize_empty_doc(en_vocab):
|
||||||
|
|
||||||
def test_serialize_doc_roundtrip_bytes(en_vocab):
|
def test_serialize_doc_roundtrip_bytes(en_vocab):
|
||||||
doc = Doc(en_vocab, words=["hello", "world"])
|
doc = Doc(en_vocab, words=["hello", "world"])
|
||||||
|
doc.cats = {"A": 0.5}
|
||||||
doc_b = doc.to_bytes()
|
doc_b = doc.to_bytes()
|
||||||
new_doc = Doc(en_vocab).from_bytes(doc_b)
|
new_doc = Doc(en_vocab).from_bytes(doc_b)
|
||||||
assert new_doc.to_bytes() == doc_b
|
assert new_doc.to_bytes() == doc_b
|
||||||
|
@ -66,12 +67,17 @@ def test_serialize_doc_exclude(en_vocab):
|
||||||
def test_serialize_doc_bin():
|
def test_serialize_doc_bin():
|
||||||
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
|
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
|
||||||
texts = ["Some text", "Lots of texts...", "..."]
|
texts = ["Some text", "Lots of texts...", "..."]
|
||||||
|
cats = {"A": 0.5}
|
||||||
nlp = English()
|
nlp = English()
|
||||||
for doc in nlp.pipe(texts):
|
for doc in nlp.pipe(texts):
|
||||||
|
doc.cats = cats
|
||||||
doc_bin.add(doc)
|
doc_bin.add(doc)
|
||||||
bytes_data = doc_bin.to_bytes()
|
bytes_data = doc_bin.to_bytes()
|
||||||
|
|
||||||
# Deserialize later, e.g. in a new process
|
# Deserialize later, e.g. in a new process
|
||||||
nlp = spacy.blank("en")
|
nlp = spacy.blank("en")
|
||||||
doc_bin = DocBin().from_bytes(bytes_data)
|
doc_bin = DocBin().from_bytes(bytes_data)
|
||||||
list(doc_bin.get_docs(nlp.vocab))
|
reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
|
||||||
|
for i, doc in enumerate(reloaded_docs):
|
||||||
|
assert doc.text == texts[i]
|
||||||
|
assert doc.cats == cats
|
||||||
|
|
|
@ -58,6 +58,7 @@ class DocBin(object):
|
||||||
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
|
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
self.spaces = []
|
self.spaces = []
|
||||||
|
self.cats = []
|
||||||
self.user_data = []
|
self.user_data = []
|
||||||
self.strings = set()
|
self.strings = set()
|
||||||
self.store_user_data = store_user_data
|
self.store_user_data = store_user_data
|
||||||
|
@ -82,6 +83,7 @@ class DocBin(object):
|
||||||
spaces = spaces.reshape((spaces.shape[0], 1))
|
spaces = spaces.reshape((spaces.shape[0], 1))
|
||||||
self.spaces.append(numpy.asarray(spaces, dtype=bool))
|
self.spaces.append(numpy.asarray(spaces, dtype=bool))
|
||||||
self.strings.update(w.text for w in doc)
|
self.strings.update(w.text for w in doc)
|
||||||
|
self.cats.append(doc.cats)
|
||||||
if self.store_user_data:
|
if self.store_user_data:
|
||||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||||
|
|
||||||
|
@ -102,6 +104,7 @@ class DocBin(object):
|
||||||
words = [vocab.strings[orth] for orth in tokens[:, orth_col]]
|
words = [vocab.strings[orth] for orth in tokens[:, orth_col]]
|
||||||
doc = Doc(vocab, words=words, spaces=spaces)
|
doc = Doc(vocab, words=words, spaces=spaces)
|
||||||
doc = doc.from_array(self.attrs, tokens)
|
doc = doc.from_array(self.attrs, tokens)
|
||||||
|
doc.cats = self.cats[i]
|
||||||
if self.store_user_data:
|
if self.store_user_data:
|
||||||
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
user_data = srsly.msgpack_loads(self.user_data[i], use_list=False)
|
||||||
doc.user_data.update(user_data)
|
doc.user_data.update(user_data)
|
||||||
|
@ -121,6 +124,7 @@ class DocBin(object):
|
||||||
self.tokens.extend(other.tokens)
|
self.tokens.extend(other.tokens)
|
||||||
self.spaces.extend(other.spaces)
|
self.spaces.extend(other.spaces)
|
||||||
self.strings.update(other.strings)
|
self.strings.update(other.strings)
|
||||||
|
self.cats.extend(other.cats)
|
||||||
if self.store_user_data:
|
if self.store_user_data:
|
||||||
self.user_data.extend(other.user_data)
|
self.user_data.extend(other.user_data)
|
||||||
|
|
||||||
|
@ -140,6 +144,7 @@ class DocBin(object):
|
||||||
"spaces": numpy.vstack(self.spaces).tobytes("C"),
|
"spaces": numpy.vstack(self.spaces).tobytes("C"),
|
||||||
"lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
|
"lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
|
||||||
"strings": list(self.strings),
|
"strings": list(self.strings),
|
||||||
|
"cats": self.cats,
|
||||||
}
|
}
|
||||||
if self.store_user_data:
|
if self.store_user_data:
|
||||||
msg["user_data"] = self.user_data
|
msg["user_data"] = self.user_data
|
||||||
|
@ -164,6 +169,7 @@ class DocBin(object):
|
||||||
flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
|
flat_spaces = flat_spaces.reshape((flat_spaces.size, 1))
|
||||||
self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
|
self.tokens = NumpyOps().unflatten(flat_tokens, lengths)
|
||||||
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
self.spaces = NumpyOps().unflatten(flat_spaces, lengths)
|
||||||
|
self.cats = msg["cats"]
|
||||||
if self.store_user_data and "user_data" in msg:
|
if self.store_user_data and "user_data" in msg:
|
||||||
self.user_data = list(msg["user_data"])
|
self.user_data = list(msg["user_data"])
|
||||||
for tokens in self.tokens:
|
for tokens in self.tokens:
|
||||||
|
|
|
@ -887,6 +887,7 @@ cdef class Doc:
|
||||||
"array_body": lambda: self.to_array(array_head),
|
"array_body": lambda: self.to_array(array_head),
|
||||||
"sentiment": lambda: self.sentiment,
|
"sentiment": lambda: self.sentiment,
|
||||||
"tensor": lambda: self.tensor,
|
"tensor": lambda: self.tensor,
|
||||||
|
"cats": lambda: self.cats,
|
||||||
}
|
}
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
|
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
|
||||||
|
@ -916,6 +917,7 @@ cdef class Doc:
|
||||||
"array_body": lambda b: None,
|
"array_body": lambda b: None,
|
||||||
"sentiment": lambda b: None,
|
"sentiment": lambda b: None,
|
||||||
"tensor": lambda b: None,
|
"tensor": lambda b: None,
|
||||||
|
"cats": lambda b: None,
|
||||||
"user_data_keys": lambda b: None,
|
"user_data_keys": lambda b: None,
|
||||||
"user_data_values": lambda b: None,
|
"user_data_values": lambda b: None,
|
||||||
}
|
}
|
||||||
|
@ -937,6 +939,8 @@ cdef class Doc:
|
||||||
self.sentiment = msg["sentiment"]
|
self.sentiment = msg["sentiment"]
|
||||||
if "tensor" not in exclude and "tensor" in msg:
|
if "tensor" not in exclude and "tensor" in msg:
|
||||||
self.tensor = msg["tensor"]
|
self.tensor = msg["tensor"]
|
||||||
|
if "cats" not in exclude and "cats" in msg:
|
||||||
|
self.cats = msg["cats"]
|
||||||
start = 0
|
start = 0
|
||||||
cdef const LexemeC* lex
|
cdef const LexemeC* lex
|
||||||
cdef unicode orth_
|
cdef unicode orth_
|
||||||
|
|
Loading…
Reference in New Issue
Block a user