diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx
index 875a55448..42e612c8e 100644
--- a/spacy/pipeline/trainable_pipe.pyx
+++ b/spacy/pipeline/trainable_pipe.pyx
@@ -59,8 +59,7 @@ cdef class TrainablePipe(Pipe):
     def distill(self,
                 teacher_pipe: Optional["TrainablePipe"],
-                teacher_docs: Iterable["Doc"],
-                student_docs: Iterable["Doc"],
+                examples: Iterable["Example"],
                 *,
                 drop: float=0.0,
                 sgd: Optional[Optimizer]=None,
                 losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
@@ -71,10 +70,8 @@ cdef class TrainablePipe(Pipe):
 
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
-        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
-        student_docs (Iterable[Doc]): Documents passed through student pipes.
-            Must contain the same tokens as `teacher_docs` but may have
-            different annotations.
+        examples (Iterable[Example]): Distillation examples. The reference
+            must contain teacher annotations (if any).
         drop (float): dropout rate.
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
@@ -89,16 +86,13 @@ cdef class TrainablePipe(Pipe):
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
-        if not any(len(doc) for doc in teacher_docs):
-            return losses
-        if not any(len(doc) for doc in student_docs):
-            return losses
+        validate_examples(examples, "TrainablePipe.distill")
         set_dropout_rate(self.model, drop)
         for node in teacher_pipe.model.walk():
             if node.name == "softmax":
                 node.attrs["softmax_normalize"] = True
-        teacher_scores = teacher_pipe.model.predict(teacher_docs)
-        student_scores, bp_student_scores = self.model.begin_update(student_docs)
+        teacher_scores = teacher_pipe.model.predict([eg.reference for eg in examples])
+        student_scores, bp_student_scores = self.model.begin_update([eg.predicted for eg in examples])
         loss, d_scores = self.get_teacher_student_loss(teacher_scores, student_scores)
         bp_student_scores(d_scores)
         if sgd is not None:
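For context, this is the call-site migration the new signature implies. A minimal sketch, not part of the diff: the model names, the "tagger" pipe, and a spaCy build that ships this distill API are all assumptions here.

    # Sketch: moving a distillation call from paired Doc iterables to Examples.
    import spacy
    from spacy.training import Example

    teacher = spacy.load("teacher_model")  # hypothetical pipeline package
    student = spacy.load("student_model")  # hypothetical pipeline package
    texts = ["I like blue eggs", "Eat blue ham"]

    # Old API: student_pipe.distill(teacher_pipe, teacher_docs, student_docs)
    # New API: one iterable of Examples. The reference Doc is what the
    # teacher predicts over; the predicted Doc is what the student updates on.
    examples = [Example(student.make_doc(t), teacher.make_doc(t)) for t in texts]
    student.get_pipe("tagger").distill(
        teacher.get_pipe("tagger"), examples, losses={}
    )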
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index e6a2bfcf0..02696739e 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -209,8 +209,7 @@ cdef class Parser(TrainablePipe):
     def distill(self,
                 teacher_pipe: Optional[TrainablePipe],
-                teacher_docs: Iterable[Doc],
-                student_docs: Iterable[Doc],
+                examples: Iterable["Example"],
                 *,
                 drop: float=0.0,
                 sgd: Optional[Optimizer]=None,
                 losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
@@ -221,11 +220,8 @@ cdef class Parser(TrainablePipe):
 
         teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
             from.
-        teacher_docs (Iterable[Doc]): Documents passed through teacher pipes.
-        student_docs (Iterable[Doc]): Documents passed through student pipes.
-            Must contain the same tokens as `teacher_docs` but may have
-            different annotations.
-        drop (float): dropout rate.
+        examples (Iterable[Example]): Distillation examples. The reference
+            must contain teacher annotations (if any).
         sgd (Optional[Optimizer]): An optimizer. Will be created via
             create_optimizer if not set.
         losses (Optional[Dict[str, float]]): Optional record of loss during
@@ -238,14 +234,13 @@
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
-        if not any(len(doc) for doc in teacher_docs):
-            return losses
-        if not any(len(doc) for doc in student_docs):
-            return losses
+        validate_examples(examples, "TransitionParser.distill")
         set_dropout_rate(self.model, drop)
 
-        teacher_step_model = teacher_pipe.model.predict(teacher_docs)
+        student_docs = [eg.predicted for eg in examples]
+
+        teacher_step_model = teacher_pipe.model.predict([eg.reference for eg in examples])
         student_step_model, backprop_tok2vec = self.model.begin_update(student_docs)
 
         # Add softmax activation, so that we can compute student losses
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index 082b424b8..9429fc746 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -639,11 +639,13 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_ner.label_data
     )
 
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
 
     for i in range(100):
         losses = {}
-        student_ner.distill(teacher_ner, docs, docs, sgd=optimizer, losses=losses)
+        student_ner.distill(teacher_ner, distill_examples, sgd=optimizer, losses=losses)
     assert losses["ner"] < 0.0001
 
     # test the trained model
diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py
index 79b0d6c5e..089c4d066 100644
--- a/spacy/tests/parser/test_parse.py
+++ b/spacy/tests/parser/test_parse.py
@@ -418,11 +418,15 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_parser.label_data
     )
 
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
 
     for i in range(200):
         losses = {}
-        student_parser.distill(teacher_parser, docs, docs, sgd=optimizer, losses=losses)
+        student_parser.distill(
+            teacher_parser, distill_examples, sgd=optimizer, losses=losses
+        )
     assert losses["parser"] < 0.0001
 
     test_text = "I like securities."
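The test updates (here and below) all share one pattern: instead of passing `eg.predicted` twice, they build annotation-free Examples straight from the raw training texts. A small self-contained illustration of what `Example.from_dict(doc, {})` produces (the sentence is arbitrary):

    from spacy.lang.en import English
    from spacy.training import Example

    nlp = English()
    doc = nlp.make_doc("I like securities.")

    # An empty gold dict yields an Example whose reference copies the tokens
    # of `doc` but carries no gold annotations; during distillation the
    # teacher fills in annotations by predicting over eg.reference.
    eg = Example.from_dict(doc, {})
    assert [t.text for t in eg.predicted] == [t.text for t in eg.reference]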
diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
index 99bd06dce..96c83a335 100644
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -217,12 +217,14 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_lemmatizer.label_data
     )
 
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
 
     for i in range(50):
         losses = {}
         student_lemmatizer.distill(
-            teacher_lemmatizer, docs, docs, sgd=optimizer, losses=losses
+            teacher_lemmatizer, distill_examples, sgd=optimizer, losses=losses
         )
 
     assert losses["trainable_lemmatizer"] < 0.00001
diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py
index 8b5226053..b2fd74142 100644
--- a/spacy/tests/pipeline/test_tagger.py
+++ b/spacy/tests/pipeline/test_tagger.py
@@ -234,11 +234,15 @@ def test_distill():
         get_examples=lambda: train_examples, labels=teacher_tagger.label_data
     )
 
-    docs = [eg.predicted for eg in train_examples]
+    distill_examples = [
+        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
+    ]
 
    for i in range(50):
         losses = {}
-        student_tagger.distill(teacher_tagger, docs, docs, sgd=optimizer, losses=losses)
+        student_tagger.distill(
+            teacher_tagger, distill_examples, sgd=optimizer, losses=losses
+        )
     assert losses["tagger"] < 0.00001
 
     test_text = "I like blue eggs"
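Putting the pieces together, the end-to-end shape of these tests looks roughly as follows. This is a sketch under assumptions: a spaCy build that includes this distill API, an illustrative two-sentence TRAIN_DATA, and iteration counts chosen arbitrarily.

    import spacy
    from spacy.training import Example

    TRAIN_DATA = [
        ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
        ("Eat blue ham", {"tags": ["V", "J", "N"]}),
    ]

    # Train a small teacher tagger on the gold annotations.
    teacher = spacy.blank("en")
    teacher_tagger = teacher.add_pipe("tagger")
    train_examples = [
        Example.from_dict(teacher.make_doc(text), annots)
        for text, annots in TRAIN_DATA
    ]
    optimizer = teacher.initialize(get_examples=lambda: train_examples)
    for _ in range(50):
        losses = {}
        teacher.update(train_examples, sgd=optimizer, losses=losses)

    # Initialize a student with the teacher's label set, then distill.
    student = spacy.blank("en")
    student_tagger = student.add_pipe("tagger")
    student_tagger.initialize(
        get_examples=lambda: train_examples, labels=teacher_tagger.label_data
    )

    # Annotation-free Examples: the teacher supplies the targets at
    # distillation time by predicting over each reference Doc.
    distill_examples = [
        Example.from_dict(teacher.make_doc(t[0]), {}) for t in TRAIN_DATA
    ]
    for _ in range(50):
        losses = {}
        student_tagger.distill(
            teacher_tagger, distill_examples, sgd=optimizer, losses=losses
        )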