mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
avoid None callback (#6100)
This commit is contained in:
parent
19fc72e4cd
commit
d53c84b6d6
|
@ -127,7 +127,7 @@ class Tok2Vec(Pipe):
|
|||
tokvecs = self.model.predict(docs)
|
||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||
for listener in self.listeners:
|
||||
listener.receive(batch_id, tokvecs, None)
|
||||
listener.receive(batch_id, tokvecs, lambda dX: [])
|
||||
return tokvecs
|
||||
|
||||
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
|
||||
|
|
|
@ -169,3 +169,22 @@ def test_tok2vec_listener():
|
|||
nlp.select_pipes(disable="tok2vec")
|
||||
assert nlp.pipe_names == ["tagger"]
|
||||
nlp("Running the pipeline with the Tok2Vec component disabled.")
|
||||
|
||||
|
||||
def test_tok2vec_listener_callback():
|
||||
orig_config = Config().from_str(cfg_string)
|
||||
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
tagger = nlp.get_pipe("tagger")
|
||||
tok2vec = nlp.get_pipe("tok2vec")
|
||||
nlp._link_components()
|
||||
docs = [nlp.make_doc("A random sentence")]
|
||||
tok2vec.model.initialize(X=docs)
|
||||
gold_array = [[1.0 for tag in ["V", "Z"]] for word in docs]
|
||||
label_sample = [tagger.model.ops.asarray(gold_array, dtype="float32")]
|
||||
tagger.model.initialize(X=docs, Y=label_sample)
|
||||
docs = [nlp.make_doc("Another entirely random sentence")]
|
||||
tok2vec.predict(docs)
|
||||
Y, get_dX = tagger.model.begin_update(docs)
|
||||
# assure that the backprop call works (and doesn't hit a 'None' callback)
|
||||
assert get_dX(Y) is not None
|
||||
|
|
Loading…
Reference in New Issue
Block a user