diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index c742aaeaa..99ed21a7c 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -219,6 +219,20 @@ class Tok2Vec(TrainablePipe): def add_label(self, label): raise NotImplementedError + def distill( + self, + teacher_pipe: "TrainablePipe", + teacher_docs: Iterable["Doc"], + student_docs: Iterable["Doc"], + *, + drop: float = 0.0, + sgd: Optional[Optimizer] = None, + losses: Optional[Dict[str, float]] = None, + ) -> Dict[str, float]: + teacher_pipe.set_annotations(teacher_docs, teacher_pipe.predict(teacher_docs)) + examples = [Example(doc, doc) for doc in student_docs] + return self.update(examples, drop=drop, sgd=sgd, losses=losses) + class Tok2VecListener(Model): """A layer that gets fed its answers from an upstream connection,