avoid None callback (#6100)

This commit is contained in:
Sofie Van Landeghem 2020-09-22 13:54:44 +02:00 committed by GitHub
parent 19fc72e4cd
commit d53c84b6d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 1 deletions

View File

@ -127,7 +127,7 @@ class Tok2Vec(Pipe):
tokvecs = self.model.predict(docs)
batch_id = Tok2VecListener.get_batch_id(docs)
for listener in self.listeners:
listener.receive(batch_id, tokvecs, None)
listener.receive(batch_id, tokvecs, lambda dX: [])
return tokvecs
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:

View File

@ -169,3 +169,22 @@ def test_tok2vec_listener():
nlp.select_pipes(disable="tok2vec")
assert nlp.pipe_names == ["tagger"]
nlp("Running the pipeline with the Tok2Vec component disabled.")
def test_tok2vec_listener_callback():
orig_config = Config().from_str(cfg_string)
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert nlp.pipe_names == ["tok2vec", "tagger"]
tagger = nlp.get_pipe("tagger")
tok2vec = nlp.get_pipe("tok2vec")
nlp._link_components()
docs = [nlp.make_doc("A random sentence")]
tok2vec.model.initialize(X=docs)
gold_array = [[1.0 for tag in ["V", "Z"]] for word in docs]
label_sample = [tagger.model.ops.asarray(gold_array, dtype="float32")]
tagger.model.initialize(X=docs, Y=label_sample)
docs = [nlp.make_doc("Another entirely random sentence")]
tok2vec.predict(docs)
Y, get_dX = tagger.model.begin_update(docs)
# assure that the backprop call works (and doesn't hit a 'None' callback)
assert get_dX(Y) is not None