Don't serialize user data in DocBin if not saving it (fix #9190) (#9226)

* Don't store user data if told not to (fix #9190)

* Add unit tests for the store_user_data setting
This commit is contained in:
Paul O'Leary McCann 2021-10-01 10:37:39 +00:00 committed by GitHub
parent 03fefa37e2
commit 8f2409e514
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 1 deletions

View File

@ -524,6 +524,30 @@ def test_roundtrip_docs_to_docbin(doc):
assert cats["TRAVEL"] == reloaded_example.reference.cats["TRAVEL"]
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
def test_docbin_user_data_serialized(doc):
    """Check that user data round-trips through a DocBin when it is stored.

    A flag is placed in ``doc.user_data``, the doc is written to disk via a
    ``DocBin`` created with ``store_user_data=True``, and the reloaded doc
    must carry the same flag.

    ``doc`` is presumably a pytest fixture supplied by the test module.
    """
    doc.user_data["check"] = True
    nlp = English()
    with make_tempdir() as tmpdir:
        output_file = tmpdir / "userdata.spacy"
        DocBin(docs=[doc], store_user_data=True).to_disk(output_file)
        reloaded_docs = DocBin().from_disk(output_file).get_docs(nlp.vocab)
        reloaded_doc = list(reloaded_docs)[0]
    # Identity comparison: the stored value must come back as the boolean
    # True itself, not merely as something that compares equal to it.
    assert reloaded_doc.user_data["check"] is True
def test_docbin_user_data_not_serialized(doc):
    """Check that user data is left out of a DocBin when storing is disabled.

    With ``store_user_data=False`` the ``DocBin`` is expected to write the
    doc without error even though its user data holds a ``set``, and the
    reloaded doc must not contain the ``"check"`` key.

    ``doc`` is presumably a pytest fixture supplied by the test module.
    """
    # this isn't serializable, but that shouldn't cause an error
    doc.user_data["check"] = set()
    vocab = English().vocab
    with make_tempdir() as tmpdir:
        target = tmpdir / "userdata.spacy"
        writer = DocBin(docs=[doc], store_user_data=False)
        writer.to_disk(target)
        reader = DocBin().from_disk(target)
        restored = list(reader.get_docs(vocab))
        first_doc = restored[0]
    assert "check" not in first_doc.user_data
@pytest.mark.parametrize(
"tokens_a,tokens_b,expected",

View File

@ -110,6 +110,7 @@ class DocBin:
self.strings.add(token.ent_kb_id_)
self.strings.add(token.ent_id_)
self.cats.append(doc.cats)
if self.store_user_data:
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
self.span_groups.append(doc.spans.to_bytes())
for key, group in doc.spans.items():