mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Don't store user data if told not to (fix #9190) * Add unit tests for the store_user_data setting
This commit is contained in:
parent
03fefa37e2
commit
8f2409e514
|
@ -524,6 +524,30 @@ def test_roundtrip_docs_to_docbin(doc):
|
|||
assert cats["TRAVEL"] == reloaded_example.reference.cats["TRAVEL"]
|
||||
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
|
||||
|
||||
def test_docbin_user_data_serialized(doc):
|
||||
doc.user_data["check"] = True
|
||||
nlp = English()
|
||||
|
||||
with make_tempdir() as tmpdir:
|
||||
output_file = tmpdir / "userdata.spacy"
|
||||
DocBin(docs=[doc], store_user_data=True).to_disk(output_file)
|
||||
reloaded_docs = DocBin().from_disk(output_file).get_docs(nlp.vocab)
|
||||
reloaded_doc = list(reloaded_docs)[0]
|
||||
|
||||
assert reloaded_doc.user_data["check"] == True
|
||||
|
||||
def test_docbin_user_data_not_serialized(doc):
|
||||
# this isn't serializable, but that shouldn't cause an error
|
||||
doc.user_data["check"] = set()
|
||||
nlp = English()
|
||||
|
||||
with make_tempdir() as tmpdir:
|
||||
output_file = tmpdir / "userdata.spacy"
|
||||
DocBin(docs=[doc], store_user_data=False).to_disk(output_file)
|
||||
reloaded_docs = DocBin().from_disk(output_file).get_docs(nlp.vocab)
|
||||
reloaded_doc = list(reloaded_docs)[0]
|
||||
|
||||
assert "check" not in reloaded_doc.user_data
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens_a,tokens_b,expected",
|
||||
|
|
|
@ -110,7 +110,8 @@ class DocBin:
|
|||
self.strings.add(token.ent_kb_id_)
|
||||
self.strings.add(token.ent_id_)
|
||||
self.cats.append(doc.cats)
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
if self.store_user_data:
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
self.span_groups.append(doc.spans.to_bytes())
|
||||
for key, group in doc.spans.items():
|
||||
for span in group:
|
||||
|
|
Loading…
Reference in New Issue
Block a user