diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py
index e423d9a19..81d0097ed 100644
--- a/spacy/tests/pipeline/test_tok2vec.py
+++ b/spacy/tests/pipeline/test_tok2vec.py
@@ -540,3 +540,70 @@ def test_tok2vec_listeners_textcat():
     assert cats1["imperative"] < 0.9
     assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
     assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
+
+
+def test_tok2vec_distill():
+    orig_config = Config().from_str(cfg_string_multi_textcat)
+    teacher_nlp = util.load_model_from_config(
+        orig_config, auto_fill=True, validate=True
+    )
+    student_nlp = util.load_model_from_config(
+        orig_config, auto_fill=True, validate=True
+    )
+
+    # Remove pipes that don't currently support distillation.
+    teacher_nlp.remove_pipe("textcat_multilabel")
+    student_nlp.remove_pipe("textcat_multilabel")
+
+    train_examples_teacher = []
+    train_examples_student = []
+    for t in TRAIN_DATA:
+        train_examples_teacher.append(
+            Example.from_dict(teacher_nlp.make_doc(t[0]), t[1])
+        )
+        train_examples_student.append(
+            Example.from_dict(student_nlp.make_doc(t[0]), t[1])
+        )
+
+    optimizer = teacher_nlp.initialize(lambda: train_examples_teacher)
+    for i in range(50):
+        losses = {}
+        teacher_nlp.update(train_examples_teacher, sgd=optimizer, losses=losses)
+
+    student_nlp.initialize(lambda: train_examples_student)
+    student_tagger = student_nlp.get_pipe("tagger")
+
+    tagger_tok2vec = student_tagger.model.get_ref("tok2vec")
+    tagger_tok2vec_forward = tagger_tok2vec._func
+
+    def mock_listener_forward(model: Tok2VecListener, inputs, is_train: bool):
+        model.attrs["last_input"] = inputs
+        return tagger_tok2vec_forward(model, inputs, is_train)
+
+    tagger_tok2vec._func = mock_listener_forward
+
+    # Since Language.distill creates a copy of the student docs to use as
+    # its internal teacher docs, we'll need to monkey-patch the tok2vec pipe's
+    # distill method.
+    student_tok2vec = student_nlp.get_pipe("tok2vec")
+    student_tok2vec._old_distill = student_tok2vec.distill
+
+    def tok2vec_distill_wrapper(
+        self,
+        teacher_pipe,
+        teacher_docs,
+        student_docs,
+        **kwargs,
+    ):
+        assert all(not doc.tensor.any() for doc in teacher_docs)
+        out = self._old_distill(teacher_pipe, teacher_docs, student_docs, **kwargs)
+        assert all(doc.tensor.any() for doc in teacher_docs)
+        return out
+
+    student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec)
+
+    student_docs = [eg.predicted for eg in train_examples_student]
+    student_nlp.distill(
+        teacher_nlp, student_docs, sgd=optimizer, losses=losses, pipe_map={}
+    )
+    assert tagger_tok2vec.attrs["last_input"] == student_docs
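
Note for reviewers on the listener patch in test_tok2vec_distill: a Thinc Model
stores its forward pass in the private _func attribute, and calls such as
Model.predict route through it, so swapping _func lets the test record exactly
which inputs the tagger's Tok2VecListener receives. A minimal sketch of the same
spying pattern on a plain layer; the Linear layer choice and the spy_forward
name are illustrative only, not part of this change:

    import numpy
    from thinc.api import Linear

    model = Linear(nO=2, nI=2)
    model.initialize()
    orig_forward = model._func

    def spy_forward(model, X, is_train):
        # Record the input before delegating to the original forward pass.
        model.attrs["last_input"] = X
        return orig_forward(model, X, is_train)

    model._func = spy_forward
    model.predict(numpy.ones((1, 2), dtype="f"))
    assert model.attrs["last_input"].shape == (1, 2)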
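Note on tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec): plain
functions are descriptors, so func.__get__(instance, cls) returns a method
bound to instance, which is how the wrapper gets a self through which it can
reach _old_distill. A self-contained illustration of the binding trick; the
Greeter class and every name in it are hypothetical:

    class Greeter:
        def greet(self, name):
            return f"hello, {name}"

    greeter = Greeter()
    greeter._old_greet = greeter.greet  # keep the original bound method

    def loud_greet(self, name):
        # `self` is the instance the function was bound to via __get__.
        return self._old_greet(name).upper()

    # Functions are descriptors: __get__ returns a method bound to `greeter`.
    greeter.greet = loud_greet.__get__(greeter, Greeter)

    assert greeter.greet("world") == "HELLO, WORLD"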