From d53c84b6d6717375ee91d2847a3d0f24beafd8d1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 22 Sep 2020 13:54:44 +0200 Subject: [PATCH] avoid None callback (#6100) --- spacy/pipeline/tok2vec.py | 2 +- spacy/tests/pipeline/test_tok2vec.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 721c67a19..9ab4e42b7 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -127,7 +127,7 @@ class Tok2Vec(Pipe): tokvecs = self.model.predict(docs) batch_id = Tok2VecListener.get_batch_id(docs) for listener in self.listeners: - listener.receive(batch_id, tokvecs, None) + listener.receive(batch_id, tokvecs, lambda dX: []) return tokvecs def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None: diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index 2e514f490..6041657d3 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -169,3 +169,22 @@ def test_tok2vec_listener(): nlp.select_pipes(disable="tok2vec") assert nlp.pipe_names == ["tagger"] nlp("Running the pipeline with the Tok2Vec component disabled.") + + +def test_tok2vec_listener_callback(): + orig_config = Config().from_str(cfg_string) + nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True) + assert nlp.pipe_names == ["tok2vec", "tagger"] + tagger = nlp.get_pipe("tagger") + tok2vec = nlp.get_pipe("tok2vec") + nlp._link_components() + docs = [nlp.make_doc("A random sentence")] + tok2vec.model.initialize(X=docs) + gold_array = [[1.0 for tag in ["V", "Z"]] for word in docs] + label_sample = [tagger.model.ops.asarray(gold_array, dtype="float32")] + tagger.model.initialize(X=docs, Y=label_sample) + docs = [nlp.make_doc("Another entirely random sentence")] + tok2vec.predict(docs) + Y, get_dX = tagger.model.begin_update(docs) + # assure that the backprop call works (and doesn't hit a 'None' callback) + assert get_dX(Y) is not None