1
1
mirror of https://github.com/explosion/spaCy.git synced 2025-01-28 02:04:07 +03:00
spaCy/spacy/tests/serialize/test_serialize_doc.py
adrianeboyd 676e75838f Include Doc.cats in serialization of Doc and DocBin ()
* Include Doc.cats in to_bytes()

* Include Doc.cats in DocBin serialization

* Add tests for serialization of cats

Test serialization of cats for Doc and DocBin.
2019-12-06 14:07:39 +01:00

84 lines
2.5 KiB
Python

# coding: utf-8
from __future__ import unicode_literals
import spacy
import pytest
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.compat import path2str
from ..util import make_tempdir
def test_serialize_empty_doc(en_vocab):
    """Check that a Doc built without any words survives a bytes round trip.

    The Doc is serialized with ``to_bytes`` and loaded into a fresh Doc with
    ``from_bytes``; the two must agree on length and on every token's text.
    """
    original = Doc(en_vocab)
    payload = original.to_bytes()
    restored = Doc(en_vocab)
    restored.from_bytes(payload)
    assert len(restored) == len(original)
    # Pairwise text comparison; with no tokens there is nothing to compare.
    assert all(
        left.text == right.text for left, right in zip(original, restored)
    )
def test_serialize_doc_roundtrip_bytes(en_vocab):
    """Check that Doc.to_bytes/from_bytes round-trips the text and Doc.cats.

    Comparing the re-serialized bytes alone cannot detect ``cats`` being
    omitted by ``to_bytes`` itself: both payloads would then lack them and
    still be equal. The restored Doc's text and cats are therefore also
    compared directly against the original.
    """
    doc = Doc(en_vocab, words=["hello", "world"])
    doc.cats = {"A": 0.5}
    doc_b = doc.to_bytes()
    new_doc = Doc(en_vocab).from_bytes(doc_b)
    assert new_doc.to_bytes() == doc_b
    assert new_doc.text == doc.text
    # Fails if cats are missing from the serialized payload.
    assert new_doc.cats == doc.cats
def test_serialize_doc_roundtrip_disk(en_vocab):
    """Check that a Doc written with to_disk and read back with from_disk
    serializes to exactly the same bytes as the original Doc."""
    source = Doc(en_vocab, words=["hello", "world"])
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "doc"
        source.to_disk(target)
        reloaded = Doc(en_vocab).from_disk(target)
        assert source.to_bytes() == reloaded.to_bytes()
def test_serialize_doc_roundtrip_disk_str_path(en_vocab):
    """Check that to_disk/from_disk accept the file location as a plain string.

    The path inside the temporary directory is converted with ``path2str``
    before use, and the reloaded Doc must serialize to the same bytes as the
    original.
    """
    source = Doc(en_vocab, words=["hello", "world"])
    with make_tempdir() as tmp_dir:
        target = path2str(tmp_dir / "doc")
        source.to_disk(target)
        reloaded = Doc(en_vocab).from_disk(target)
        assert source.to_bytes() == reloaded.to_bytes()
def test_serialize_doc_exclude(en_vocab):
    """Check how ``exclude`` controls user_data in Doc byte serialization.

    user_data is preserved by default. It is dropped when "user_data" is
    excluded on either the deserializing side (``from_bytes``) or the
    serializing side (``to_bytes``). Passing ``user_data=False`` to
    ``to_bytes`` or ``tensor=False`` to ``from_bytes`` as keyword arguments
    raises ValueError.
    """
    doc = Doc(en_vocab, words=["hello", "world"])
    doc.user_data["foo"] = "bar"
    serialized = doc.to_bytes()

    kept = Doc(en_vocab).from_bytes(serialized)
    assert kept.user_data["foo"] == "bar"

    excluded_on_load = Doc(en_vocab).from_bytes(serialized, exclude=["user_data"])
    assert not excluded_on_load.user_data

    excluded_on_save = Doc(en_vocab).from_bytes(doc.to_bytes(exclude=["user_data"]))
    assert not excluded_on_save.user_data

    with pytest.raises(ValueError):
        doc.to_bytes(user_data=False)
    with pytest.raises(ValueError):
        Doc(en_vocab).from_bytes(serialized, tensor=False)
def test_serialize_doc_bin():
    """Check that DocBin round-trips every Doc's text and Doc.cats through bytes.

    Docs are produced by an ``English()`` pipeline, packed into a DocBin and
    serialized. They are then restored using the vocab of a separately created
    ``spacy.blank("en")`` pipeline. The number of restored docs is asserted
    first, so a DocBin that silently drops docs cannot pass vacuously.
    """
    doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
    texts = ["Some text", "Lots of texts...", "..."]
    cats = {"A": 0.5}
    nlp = English()
    for doc in nlp.pipe(texts):
        doc.cats = cats
        doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()

    # Deserialize later, e.g. in a new process
    nlp = spacy.blank("en")
    doc_bin = DocBin().from_bytes(bytes_data)
    reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
    # Without this check, the pairwise loop below would silently skip any
    # docs lost in the round trip.
    assert len(reloaded_docs) == len(texts)
    for text, doc in zip(texts, reloaded_docs):
        assert doc.text == text
        assert doc.cats == cats