Add logger warning when serializing user hooks (#6595)

Add a warning that user hooks are lost on serialization.

Add a `user_hooks` exclude to skip the warning with pickle.
This commit is contained in:
Adriane Boyd 2020-12-29 11:54:32 +01:00 committed by GitHub
parent cabd4ae5b1
commit 5ca57d8221
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 1 deletions

View File

@ -123,6 +123,8 @@ class Warnings:
"token '{text}'. Check that your pipeline includes components that " "token '{text}'. Check that your pipeline includes components that "
"assign token.pos, typically 'tagger'+'attribute_ruler' or " "assign token.pos, typically 'tagger'+'attribute_ruler' or "
"'morphologizer'.") "'morphologizer'.")
W109 = ("Unable to save user hooks while serializing the doc. Re-add any "
"required user hooks to the doc after processing.")
@add_codes @add_codes

View File

@ -1,5 +1,7 @@
import pytest import pytest
import numpy import numpy
import logging
import mock
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.lexeme import Lexeme from spacy.lexeme import Lexeme
@ -147,6 +149,17 @@ def test_doc_api_serialize(en_tokenizer, text):
assert [t.text for t in tokens] == [t.text for t in new_tokens] assert [t.text for t in tokens] == [t.text for t in new_tokens]
assert [t.orth for t in tokens] == [t.orth for t in new_tokens] assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
def inner_func(d1, d2):
return "hello!"
logger = logging.getLogger("spacy")
with mock.patch.object(logger, "warning") as mock_warning:
_ = tokens.to_bytes()
mock_warning.assert_not_called()
tokens.user_hooks["similarity"] = inner_func
_ = tokens.to_bytes()
mock_warning.assert_called_once()
def test_doc_api_set_ents(en_tokenizer): def test_doc_api_set_ents(en_tokenizer):
text = "I use goggle chrone to surf the web" text = "I use goggle chrone to surf the web"

View File

@ -1273,6 +1273,8 @@ cdef class Doc:
serializers["user_data_keys"] = lambda: srsly.msgpack_dumps(user_data_keys) serializers["user_data_keys"] = lambda: srsly.msgpack_dumps(user_data_keys)
if "user_data_values" not in exclude: if "user_data_values" not in exclude:
serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values) serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values)
if "user_hooks" not in exclude and any((self.user_hooks, self.user_token_hooks, self.user_span_hooks)):
util.logger.warning(Warnings.W109)
return util.to_dict(serializers, exclude) return util.to_dict(serializers, exclude)
def from_dict(self, msg, *, exclude=tuple()): def from_dict(self, msg, *, exclude=tuple()):
@ -1649,7 +1651,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
def pickle_doc(doc): def pickle_doc(doc):
bytes_data = doc.to_bytes(exclude=["vocab", "user_data"]) bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks, hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
doc.user_token_hooks) doc.user_token_hooks)
return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data)) return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))