From 5ca57d822140a84bb790496f3b797e8374573b95 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 29 Dec 2020 11:54:32 +0100 Subject: [PATCH] Add logger warning when serializing user hooks (#6595) Add a warning that user hooks are lost on serialization. Add a `user_hooks` exclude to skip the warning with pickle. --- spacy/errors.py | 2 ++ spacy/tests/doc/test_doc_api.py | 13 +++++++++++++ spacy/tokens/doc.pyx | 4 +++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/spacy/errors.py b/spacy/errors.py index 332e7df1c..6d3591ab5 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -123,6 +123,8 @@ class Warnings: "token '{text}'. Check that your pipeline includes components that " "assign token.pos, typically 'tagger'+'attribute_ruler' or " "'morphologizer'.") + W109 = ("Unable to save user hooks while serializing the doc. Re-add any " + "required user hooks to the doc after processing.") @add_codes diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index e5ed359ee..04eb0e092 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -1,5 +1,7 @@ import pytest import numpy +import logging +import mock from spacy.tokens import Doc, Span from spacy.vocab import Vocab from spacy.lexeme import Lexeme @@ -147,6 +149,17 @@ def test_doc_api_serialize(en_tokenizer, text): assert [t.text for t in tokens] == [t.text for t in new_tokens] assert [t.orth for t in tokens] == [t.orth for t in new_tokens] + def inner_func(d1, d2): + return "hello!" + + logger = logging.getLogger("spacy") + with mock.patch.object(logger, "warning") as mock_warning: + _ = tokens.to_bytes() + mock_warning.assert_not_called() + tokens.user_hooks["similarity"] = inner_func + _ = tokens.to_bytes() + mock_warning.assert_called_once() + def test_doc_api_set_ents(en_tokenizer): text = "I use goggle chrone to surf the web" diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index c824b2752..2f9d2e015 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1273,6 +1273,8 @@ cdef class Doc: serializers["user_data_keys"] = lambda: srsly.msgpack_dumps(user_data_keys) if "user_data_values" not in exclude: serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values) + if "user_hooks" not in exclude and any((self.user_hooks, self.user_token_hooks, self.user_span_hooks)): + util.logger.warning(Warnings.W109) return util.to_dict(serializers, exclude) def from_dict(self, msg, *, exclude=tuple()): @@ -1649,7 +1651,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end): def pickle_doc(doc): - bytes_data = doc.to_bytes(exclude=["vocab", "user_data"]) + bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"]) hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks, doc.user_token_hooks) return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))