mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Fix issue #6950: allow pickling Tok2Vec with listeners
This commit is contained in:
parent
61b04a70d5
commit
ad9ce3c8f6
|
@ -121,7 +121,7 @@ class Tok2Vec(TrainablePipe):
|
|||
tokvecs = self.model.predict(docs)
|
||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
||||
for listener in self.listeners:
|
||||
listener.receive(batch_id, tokvecs, lambda dX: [])
|
||||
listener.receive(batch_id, tokvecs, _empty_backprop)
|
||||
return tokvecs
|
||||
|
||||
def set_annotations(self, docs: Sequence[Doc], tokvecses) -> None:
|
||||
|
@ -300,3 +300,7 @@ def forward(model: Tok2VecListener, inputs, is_train: bool):
|
|||
else:
|
||||
outputs = [doc.tensor for doc in inputs]
|
||||
return outputs, lambda dX: []
|
||||
|
||||
|
||||
def _empty_backprop(dX): # for pickling
|
||||
return []
|
||||
|
|
59
spacy/tests/regression/test_issue6950.py
Normal file
59
spacy/tests/regression/test_issue6950.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from spacy.lang.en import English
|
||||
from spacy.training import Example
|
||||
from spacy.util import load_config_from_str
|
||||
import pickle
|
||||
|
||||
|
||||
CONFIG = """
|
||||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["tok2vec", "tagger"]
|
||||
|
||||
[components]
|
||||
|
||||
[components.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
[components.tok2vec.model]
|
||||
@architectures = "spacy.Tok2Vec.v1"
|
||||
|
||||
[components.tok2vec.model.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
attrs = ["NORM","PREFIX","SUFFIX","SHAPE"]
|
||||
rows = [5000,2500,2500,2500]
|
||||
include_static_vectors = false
|
||||
|
||||
[components.tok2vec.model.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
width = 96
|
||||
depth = 4
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
|
||||
[components.ner]
|
||||
factory = "ner"
|
||||
|
||||
[components.tagger]
|
||||
factory = "tagger"
|
||||
|
||||
[components.tagger.model]
|
||||
@architectures = "spacy.Tagger.v1"
|
||||
nO = null
|
||||
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
upstream = "*"
|
||||
"""
|
||||
|
||||
|
||||
def test_issue6950():
|
||||
"""Test that the nlp object with initialized tok2vec with listeners pickles
|
||||
correctly (and doesn't have lambdas).
|
||||
"""
|
||||
nlp = English.from_config(load_config_from_str(CONFIG))
|
||||
nlp.initialize(lambda: [Example.from_dict(nlp.make_doc("hello"), {"tags": ["V"]})])
|
||||
pickle.dumps(nlp)
|
||||
nlp("hello")
|
||||
pickle.dumps(nlp)
|
Loading…
Reference in New Issue
Block a user