mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
fix serialization of empty doc + unit test
This commit is contained in:
parent
5847be6022
commit
59000ee21d
11
spacy/tests/regression/test_issue5141.py
Normal file
11
spacy/tests/regression/test_issue5141.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from spacy.tokens import DocBin
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue5141(en_vocab):
|
||||||
|
""" Ensure an empty DocBin does not crash on serialization """
|
||||||
|
doc_bin = DocBin(attrs=["DEP", "HEAD"])
|
||||||
|
assert list(doc_bin.get_docs(en_vocab)) == []
|
||||||
|
doc_bin_bytes = doc_bin.to_bytes()
|
||||||
|
|
||||||
|
doc_bin_2 = DocBin().from_bytes(doc_bin_bytes)
|
||||||
|
assert list(doc_bin_2.get_docs(en_vocab)) == []
|
|
@ -135,10 +135,13 @@ class DocBin(object):
|
||||||
for tokens in self.tokens:
|
for tokens in self.tokens:
|
||||||
assert len(tokens.shape) == 2, tokens.shape # this should never happen
|
assert len(tokens.shape) == 2, tokens.shape # this should never happen
|
||||||
lengths = [len(tokens) for tokens in self.tokens]
|
lengths = [len(tokens) for tokens in self.tokens]
|
||||||
|
tokens = numpy.vstack(self.tokens) if self.tokens else numpy.asarray([])
|
||||||
|
spaces = numpy.vstack(self.spaces) if self.spaces else numpy.asarray([])
|
||||||
|
|
||||||
msg = {
|
msg = {
|
||||||
"attrs": self.attrs,
|
"attrs": self.attrs,
|
||||||
"tokens": numpy.vstack(self.tokens).tobytes("C"),
|
"tokens": tokens.tobytes("C"),
|
||||||
"spaces": numpy.vstack(self.spaces).tobytes("C"),
|
"spaces": spaces.tobytes("C"),
|
||||||
"lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
|
"lengths": numpy.asarray(lengths, dtype="int32").tobytes("C"),
|
||||||
"strings": list(self.strings),
|
"strings": list(self.strings),
|
||||||
"cats": self.cats,
|
"cats": self.cats,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user