Add logger warning when serializing user hooks (#6595)

Add a warning that user hooks are lost on serialization.

Add a `user_hooks` exclude to skip the warning with pickle.
This commit is contained in:
Adriane Boyd 2020-12-29 11:54:32 +01:00 committed by GitHub
parent cabd4ae5b1
commit 5ca57d8221
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 1 deletions

View File

@ -123,6 +123,8 @@ class Warnings:
"token '{text}'. Check that your pipeline includes components that "
"assign token.pos, typically 'tagger'+'attribute_ruler' or "
"'morphologizer'.")
W109 = ("Unable to save user hooks while serializing the doc. Re-add any "
"required user hooks to the doc after processing.")
@add_codes

View File

@ -1,5 +1,7 @@
import pytest
import numpy
import logging
import mock
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.lexeme import Lexeme
@ -147,6 +149,17 @@ def test_doc_api_serialize(en_tokenizer, text):
assert [t.text for t in tokens] == [t.text for t in new_tokens]
assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
def inner_func(d1, d2):
return "hello!"
logger = logging.getLogger("spacy")
with mock.patch.object(logger, "warning") as mock_warning:
_ = tokens.to_bytes()
mock_warning.assert_not_called()
tokens.user_hooks["similarity"] = inner_func
_ = tokens.to_bytes()
mock_warning.assert_called_once()
def test_doc_api_set_ents(en_tokenizer):
text = "I use goggle chrone to surf the web"

View File

@ -1273,6 +1273,8 @@ cdef class Doc:
serializers["user_data_keys"] = lambda: srsly.msgpack_dumps(user_data_keys)
if "user_data_values" not in exclude:
serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values)
if "user_hooks" not in exclude and any((self.user_hooks, self.user_token_hooks, self.user_span_hooks)):
util.logger.warning(Warnings.W109)
return util.to_dict(serializers, exclude)
def from_dict(self, msg, *, exclude=tuple()):
@ -1649,7 +1651,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
def pickle_doc(doc):
bytes_data = doc.to_bytes(exclude=["vocab", "user_data"])
bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
doc.user_token_hooks)
return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))