From c96786152baef23e8b336523aa23952d86c67b3c Mon Sep 17 00:00:00 2001
From: shademe
Date: Wed, 14 Dec 2022 12:47:20 +0100
Subject: [PATCH] `Tok2Vec`: Add `distill` method

---
Review note: `distill` now materializes `teacher_docs` before use and carries
a docstring; the blob hash on the `index` line predates that revision.

 spacy/pipeline/tok2vec.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index c742aaeaa..99ed21a7c 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -219,6 +219,36 @@ class Tok2Vec(TrainablePipe):
     def add_label(self, label):
         raise NotImplementedError
 
+    def distill(
+        self,
+        teacher_pipe: "TrainablePipe",
+        teacher_docs: Iterable["Doc"],
+        student_docs: Iterable["Doc"],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+    ) -> Dict[str, float]:
+        """Annotate the teacher docs with the teacher pipe, then update this
+        pipe on the student docs.
+
+        teacher_pipe (TrainablePipe): Pipe whose predictions are written to
+            teacher_docs.
+        teacher_docs (Iterable[Doc]): Docs annotated by teacher_pipe.
+        student_docs (Iterable[Doc]): Docs this pipe is updated on.
+        drop (float): Passed through to update.
+        sgd (Optional[Optimizer]): Passed through to update.
+        losses (Optional[Dict[str, float]]): Passed through to update.
+        RETURNS (Dict[str, float]): Whatever update returns.
+        """
+        # Materialize once: predict() and set_annotations() must see the same
+        # docs, and a one-shot iterator would be exhausted after predict().
+        teacher_docs = list(teacher_docs)
+        teacher_pipe.set_annotations(teacher_docs, teacher_pipe.predict(teacher_docs))
+        # Each student doc is passed as both arguments of its Example.
+        examples = [Example(doc, doc) for doc in student_docs]
+        return self.update(examples, drop=drop, sgd=sgd, losses=losses)
+
 
 class Tok2VecListener(Model):
     """A layer that gets fed its answers from an upstream connection,