diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index d82979752..9d7265fbb 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -542,8 +542,48 @@ def test_tok2vec_listeners_textcat(): assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"] -def test_tok2vec_distill(): - orig_config = Config().from_str(cfg_string_multi_textcat) +cfg_string_distillation = """ + [nlp] + lang = "en" + pipeline = ["tok2vec","tagger"] + + [components] + + [components.tagger] + factory = "tagger" + + [components.tagger.model] + @architectures = "spacy.Tagger.v2" + nO = null + + [components.tagger.model.tok2vec] + @architectures = "spacy.Tok2VecListener.v1" + width = ${components.tok2vec.model.encode.width} + + [components.tok2vec] + factory = "tok2vec" + + [components.tok2vec.model] + @architectures = "spacy.Tok2Vec.v2" + + [components.tok2vec.model.embed] + @architectures = "spacy.MultiHashEmbed.v2" + width = ${components.tok2vec.model.encode.width} + rows = [2000, 1000, 1000, 1000] + attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] + include_static_vectors = false + + [components.tok2vec.model.encode] + @architectures = "spacy.MaxoutWindowEncoder.v2" + width = 96 + depth = 4 + window_size = 1 + maxout_pieces = 3 + """ + + +def test_tok2vec_distillation_teacher_annotations(): + orig_config = Config().from_str(cfg_string_distillation) teacher_nlp = util.load_model_from_config( orig_config, auto_fill=True, validate=True ) @@ -551,10 +591,6 @@ def test_tok2vec_distill(): orig_config, auto_fill=True, validate=True ) - # Remove pipes that don't currently support distillation. - teacher_nlp.remove_pipe("textcat_multilabel") - student_nlp.remove_pipe("textcat_multilabel") - train_examples_teacher = [] train_examples_student = [] for t in TRAIN_DATA: @@ -571,39 +607,25 @@ def test_tok2vec_distill(): teacher_nlp.update(train_examples_teacher, sgd=optimizer, losses=losses) student_nlp.initialize(lambda: train_examples_student) - student_tagger = student_nlp.get_pipe("tagger") - tagger_tok2vec = student_tagger.model.get_ref("tok2vec") - tagger_tok2vec_forward = tagger_tok2vec._func - - def mock_listener_forward(model: Tok2VecListener, inputs, is_train: bool): - model.attrs["last_input"] = inputs - return tagger_tok2vec_forward(model, inputs, is_train) - - tagger_tok2vec._func = mock_listener_forward - - # Since Language.distill creates a copy of the student docs to use as - # its internal teacher docs, we'll need to monkey-patch the tok2vec pipe's - # distill method. + # Since Language.distill creates a copy of the examples to use as + # its internal teacher/student docs, we'll need to monkey-patch the + # tok2vec pipe's distill method. student_tok2vec = student_nlp.get_pipe("tok2vec") student_tok2vec._old_distill = student_tok2vec.distill def tok2vec_distill_wrapper( self, teacher_pipe, - teacher_docs, - student_docs, + examples, **kwargs, ): - assert all(not doc.tensor.any() for doc in teacher_docs) - out = self._old_distill(teacher_pipe, teacher_docs, student_docs, **kwargs) - assert all(doc.tensor.any() for doc in teacher_docs) + assert all(not eg.reference.tensor.any() for eg in examples) + out = self._old_distill(teacher_pipe, examples, **kwargs) + assert all(eg.reference.tensor.any() for eg in examples) return out student_tok2vec.distill = tok2vec_distill_wrapper.__get__(student_tok2vec, Tok2Vec) - - student_docs = [eg.predicted for eg in train_examples_student] student_nlp.distill( - teacher_nlp, student_docs, sgd=optimizer, losses=losses, pipe_map={} + teacher_nlp, train_examples_student, sgd=optimizer, losses=losses ) - assert tagger_tok2vec.attrs["last_input"] == student_docs