Pickle Doc._context (#9603)

This commit is contained in:
Adriane Boyd 2021-11-03 09:14:29 +01:00 committed by GitHub
parent 61daac54e4
commit 6eee024ff6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@@ -5,9 +5,11 @@ from spacy.compat import pickle
def test_pickle_single_doc():
    """Round-trip one Doc through pickle and check that text and _context survive."""
    nlp = Language()
    doc = nlp("pickle roundtrip")
    # Assign _context explicitly so the round trip has a concrete value to preserve.
    doc._context = 3
    data = pickle.dumps(doc, 1)
    doc2 = pickle.loads(data)
    assert doc2.text == "pickle roundtrip"
    assert doc2._context == 3
def test_list_of_docs_pickles_efficiently(): def test_list_of_docs_pickles_efficiently():

View File

@@ -1710,17 +1710,18 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
def pickle_doc(doc):
    """Reduce a Doc to a (callable, args) pair for pickling.

    The args are the doc's vocab, a pickled tuple of
    (user_data, user_hooks, user_span_hooks, user_token_hooks, _context),
    and the doc serialized with to_bytes. unpickle_doc consumes them in
    that order.
    """
    # user_data and the hooks travel in the separately pickled tuple below,
    # so they are excluded from the to_bytes payload.
    bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
    hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
                      doc.user_token_hooks, doc._context)
    return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))
def unpickle_doc(vocab, hooks_and_data, bytes_data):
    """Rebuild a Doc from the arguments produced by pickle_doc.

    hooks_and_data is a pickled 5-tuple of
    (user_data, doc hooks, span hooks, token hooks, _context);
    bytes_data is the doc serialized with to_bytes.
    """
    user_data, doc_hooks, span_hooks, token_hooks, _context = srsly.pickle_loads(hooks_and_data)
    # user_data is passed to the constructor, so skip it when loading the bytes.
    doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data, exclude=["user_data"])
    doc.user_hooks.update(doc_hooks)
    doc.user_span_hooks.update(span_hooks)
    doc.user_token_hooks.update(token_hooks)
    doc._context = _context
    return doc