spaCy/spacy/tests/serialize/test_serialize_docbin.py
Lj Miranda 7d50804644
Migrate regression tests into the main test suite
* Migrate regressions 1-1000

* Move serialize test to correct file

* Remove tests that won't work in v3

* Migrate regressions 1000-1500

Removed regression test 1250 because v3 doesn't support the old LEX
scheme anymore.

* Add missing imports in serializer tests

* Migrate tests 1500-2000

* Migrate regressions from 2000-2500

* Migrate regressions from 2501-3000

* Migrate regressions from 3000-3501

* Migrate regressions from 3501-4000

* Migrate regressions from 4001-4500

* Migrate regressions from 4501-5000

* Migrate regressions from 5001-5501

* Migrate regressions from 5501 to 7000

* Migrate regressions from 7001 to 8000

* Migrate remaining regression tests

* Fix missing imports

* Update docs with new system [ci skip]

* Update CONTRIBUTING.md

- Fix formatting
- Update wording

* Remove lemmatizer tests in el lang

* Move a few tests into the general tokenizer

* Separate Doc and DocBin tests
2021-12-04 20:34:48 +01:00

107 lines
3.6 KiB
Python

import pytest
import spacy
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
from spacy.tokens.underscore import Underscore
@pytest.mark.issue(4367)
def test_issue4367():
    """Constructing a DocBin must not raise, with default or explicit attrs."""
    # No arguments at all: the default attribute set.
    DocBin()
    # Explicit attribute lists, from a single attr up to several entity attrs.
    for attr_names in (["LEMMA"], ["LEMMA", "ENT_IOB", "ENT_TYPE"]):
        DocBin(attrs=attr_names)
@pytest.mark.issue(4528)
def test_issue4528(en_vocab):
    """Check that DocBin round-trips Doc.user_data when store_user_data=True.

    Two kinds of key are checked: a plain string key, and the tuple-shaped
    key under which extension attribute values are kept in user_data.
    """
    ext_key = ("._.", "foo", None, None)
    source_doc = Doc(en_vocab, words=["hello", "world"])
    source_doc.user_data["foo"] = "bar"
    # This is how extension attribute values are stored in the user data
    source_doc.user_data[ext_key] = "bar"

    writer = DocBin(store_user_data=True)
    writer.add(source_doc)
    reader = DocBin(store_user_data=True).from_bytes(writer.to_bytes())
    restored = list(reader.get_docs(en_vocab))[0]

    for key in ("foo", ext_key):
        assert restored.user_data[key] == "bar"
@pytest.mark.issue(5141)
def test_issue5141(en_vocab):
    """An empty DocBin must survive to_bytes/from_bytes without crashing."""
    empty_bin = DocBin(attrs=["DEP", "HEAD"])
    assert not list(empty_bin.get_docs(en_vocab))

    payload = empty_bin.to_bytes()
    restored_bin = DocBin().from_bytes(payload)
    # The deserialized bin must be empty as well.
    assert not list(restored_bin.get_docs(en_vocab))
def test_serialize_doc_bin():
    """Round-trip several docs through DocBin and check what is restored.

    Text, cats, the "start" span group, a custom NORM and a custom ENT_ID set
    on the first token are written to bytes. They must all be recovered when
    the bytes are loaded against the vocab of a freshly created pipeline.
    """
    doc_bin = DocBin(
        attrs=["LEMMA", "ENT_IOB", "ENT_TYPE", "NORM", "ENT_ID"], store_user_data=True
    )
    texts = ["Some text", "Lots of texts...", "..."]
    cats = {"A": 0.5}
    nlp = English()
    for doc in nlp.pipe(texts):
        doc.cats = cats
        doc.spans["start"] = [doc[0:2]]
        doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
        doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
        doc_bin.add(doc)
    bytes_data = doc_bin.to_bytes()

    # Deserialize later, e.g. in a new process
    nlp = spacy.blank("en")
    doc_bin = DocBin().from_bytes(bytes_data)
    reloaded_docs = list(doc_bin.get_docs(nlp.vocab))
    # Without this check, a DocBin that dropped docs (or returned none) would
    # make the per-doc loop below pass vacuously.
    assert len(reloaded_docs) == len(texts)
    for text, doc in zip(texts, reloaded_docs):
        assert doc.text == text
        assert doc.cats == cats
        assert len(doc.spans) == 1
        assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
        assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"
def test_serialize_doc_bin_unknown_spaces(en_vocab):
    """has_unknown_spaces and the derived text must survive a DocBin round trip.

    A Doc built without `spaces` reports unknown spaces and renders with a
    space after every token. A Doc built with explicit spaces does not.
    """
    words = ["that", "'s"]
    unknown = Doc(en_vocab, words=words)
    known = Doc(en_vocab, words=words, spaces=[False, False])

    assert unknown.has_unknown_spaces
    assert unknown.text == "that 's "
    assert not known.has_unknown_spaces
    assert known.text == "that's"

    payload = DocBin(docs=[unknown, known]).to_bytes()
    # Unpacking also checks that exactly two docs come back.
    restored_unknown, restored_known = DocBin().from_bytes(payload).get_docs(en_vocab)

    assert restored_unknown.has_unknown_spaces
    assert restored_unknown.text == "that 's "
    assert not restored_known.has_unknown_spaces
    assert restored_known.text == "that's"
@pytest.mark.parametrize(
    "writer_flag,reader_flag,reader_value",
    [
        (True, True, "bar"),
        (True, False, "bar"),
        (False, True, "nothing"),
        (False, False, "nothing"),
    ],
)
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
    """Test that custom extensions are correctly serialized in DocBin.

    Per the parametrization, the extension value written on the source doc
    is restored only when the *writing* DocBin has store_user_data=True. The
    reader's flag makes no difference. Otherwise the extension default is
    returned.
    """
    # force=True so a stale "foo" registration left behind by another test
    # cannot make set_extension raise before this test even starts.
    Doc.set_extension("foo", default="nothing", force=True)
    try:
        doc = Doc(en_vocab, words=["hello", "world"])
        doc._.foo = "bar"
        doc_bin_1 = DocBin(store_user_data=writer_flag)
        doc_bin_1.add(doc)
        doc_bin_bytes = doc_bin_1.to_bytes()
        doc_bin_2 = DocBin(store_user_data=reader_flag).from_bytes(doc_bin_bytes)
        doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
        assert doc_2._.foo == reader_value
    finally:
        # Reset global extension state even when the assertion fails.
        # Otherwise "foo" stays registered and leaks into the remaining
        # parametrized cases and later tests.
        Underscore.doc_extensions = {}