mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
avoid None callback (#6100)
This commit is contained in:
parent
19fc72e4cd
commit
d53c84b6d6
|
@ -127,7 +127,7 @@ class Tok2Vec(Pipe):
|
||||||
tokvecs = self.model.predict(docs)
|
tokvecs = self.model.predict(docs)
|
||||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||||
for listener in self.listeners:
|
for listener in self.listeners:
|
||||||
listener.receive(batch_id, tokvecs, None)
|
listener.receive(batch_id, tokvecs, lambda dX: [])
|
||||||
return tokvecs
|
return tokvecs
|
||||||
|
|
||||||
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
|
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
|
||||||
|
|
|
@ -169,3 +169,22 @@ def test_tok2vec_listener():
|
||||||
nlp.select_pipes(disable="tok2vec")
|
nlp.select_pipes(disable="tok2vec")
|
||||||
assert nlp.pipe_names == ["tagger"]
|
assert nlp.pipe_names == ["tagger"]
|
||||||
nlp("Running the pipeline with the Tok2Vec component disabled.")
|
nlp("Running the pipeline with the Tok2Vec component disabled.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_tok2vec_listener_callback():
|
||||||
|
orig_config = Config().from_str(cfg_string)
|
||||||
|
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
|
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
|
tagger = nlp.get_pipe("tagger")
|
||||||
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
nlp._link_components()
|
||||||
|
docs = [nlp.make_doc("A random sentence")]
|
||||||
|
tok2vec.model.initialize(X=docs)
|
||||||
|
gold_array = [[1.0 for tag in ["V", "Z"]] for word in docs]
|
||||||
|
label_sample = [tagger.model.ops.asarray(gold_array, dtype="float32")]
|
||||||
|
tagger.model.initialize(X=docs, Y=label_sample)
|
||||||
|
docs = [nlp.make_doc("Another entirely random sentence")]
|
||||||
|
tok2vec.predict(docs)
|
||||||
|
Y, get_dX = tagger.model.begin_update(docs)
|
||||||
|
# assure that the backprop call works (and doesn't hit a 'None' callback)
|
||||||
|
assert get_dX(Y) is not None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user