mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Don't store user data if told not to (fix #9190)
* Add unit tests for the store_user_data setting
This commit is contained in:
parent
03fefa37e2
commit
8f2409e514
|
@ -524,6 +524,30 @@ def test_roundtrip_docs_to_docbin(doc):
|
||||||
assert cats["TRAVEL"] == reloaded_example.reference.cats["TRAVEL"]
|
assert cats["TRAVEL"] == reloaded_example.reference.cats["TRAVEL"]
|
||||||
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
|
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
|
||||||
|
|
||||||
|
def test_docbin_user_data_serialized(doc):
    """Check that user data survives a DocBin disk round trip when requested.

    A flag is stored in ``doc.user_data``, the doc is written to disk through
    a ``DocBin`` created with ``store_user_data=True``, and the flag is
    expected to be present on the doc that is read back.
    """
    doc.user_data["check"] = True
    nlp = English()

    with make_tempdir() as tmpdir:
        output_file = tmpdir / "userdata.spacy"
        DocBin(docs=[doc], store_user_data=True).to_disk(output_file)
        reloaded_docs = DocBin().from_disk(output_file).get_docs(nlp.vocab)
        # Materialize inside the context so loading finishes while the
        # temporary directory still exists.
        reloaded_doc = list(reloaded_docs)[0]

    # Identity check instead of `== True`: also verifies that the value comes
    # back as a real bool rather than something that merely equals True.
    assert reloaded_doc.user_data["check"] is True
|
||||||
|
|
||||||
|
def test_docbin_user_data_not_serialized(doc):
    """Check that user data is left out of a DocBin when storing is disabled.

    The doc carries a value in ``doc.user_data`` that is not serializable.
    With ``store_user_data=False`` the write must still succeed, and the
    reloaded doc must not contain the key.
    """
    # Deliberately not serializable: writing must not fail because of it.
    doc.user_data["check"] = set()
    vocab = English().vocab

    with make_tempdir() as tmpdir:
        target = tmpdir / "userdata.spacy"
        writer = DocBin(docs=[doc], store_user_data=False)
        writer.to_disk(target)
        reader = DocBin().from_disk(target)
        restored = list(reader.get_docs(vocab))
        first_doc = restored[0]

    assert "check" not in first_doc.user_data
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"tokens_a,tokens_b,expected",
|
"tokens_a,tokens_b,expected",
|
||||||
|
|
|
@ -110,7 +110,8 @@ class DocBin:
|
||||||
self.strings.add(token.ent_kb_id_)
|
self.strings.add(token.ent_kb_id_)
|
||||||
self.strings.add(token.ent_id_)
|
self.strings.add(token.ent_id_)
|
||||||
self.cats.append(doc.cats)
|
self.cats.append(doc.cats)
|
||||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
if self.store_user_data:
|
||||||
|
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||||
self.span_groups.append(doc.spans.to_bytes())
|
self.span_groups.append(doc.spans.to_bytes())
|
||||||
for key, group in doc.spans.items():
|
for key, group in doc.spans.items():
|
||||||
for span in group:
|
for span in group:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user