From 553181b3ca178805f817ba8e6b1a518ab508768d Mon Sep 17 00:00:00 2001
From: shademe
Date: Mon, 16 Jan 2023 17:08:25 +0100
Subject: [PATCH] Update `distill` signature to accept `Example`s instead of separate teacher and student docs

---
Review note: revised to materialize `examples` once before it is traversed
twice (it is typed as `Iterable`, so it may be a one-shot generator). The
blob ids on the `index` line predate this revision.

 spacy/pipeline/tok2vec.py | 32 +++++++++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index fe768fb06..aa6195ab6 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -192,14 +192,40 @@ class Tok2Vec(TrainablePipe):
 
     def distill(
         self,
-        teacher_pipe: "TrainablePipe",
-        teacher_docs: Iterable["Doc"],
-        student_docs: Iterable["Doc"],
+        teacher_pipe: Optional["TrainablePipe"],
+        examples: Iterable["Example"],
         *,
         drop: float = 0.0,
         sgd: Optional[Optimizer] = None,
         losses: Optional[Dict[str, float]] = None,
     ) -> Dict[str, float]:
+        """Train a pipe (the student) on the predictions of another pipe
+        (the teacher). The student is typically trained on the probability
+        distribution of the teacher, but details may differ per pipe.
+
+        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+            from.
+        examples (Iterable[Example]): Distillation examples. The reference
+            and predicted docs must have the same number of tokens and the
+            same orthography.
+        drop (float): dropout rate.
+        sgd (Optional[Optimizer]): An optimizer. Will be created via
+            create_optimizer if not set.
+        losses (Optional[Dict[str, float]]): Optional record of loss during
+            distillation.
+        RETURNS: The updated losses dictionary.
+
+        DOCS: https://spacy.io/api/tok2vec#distill
+        """
+        # By default we require a teacher pipe, but there are downstream
+        # implementations that don't require a pipe.
+        if teacher_pipe is None:
+            raise ValueError(Errors.E4002.format(name=self.name))
+        # Materialize once: `examples` may be a one-shot iterator (e.g. a
+        # generator), and it is traversed twice below.
+        examples = list(examples)
+        teacher_docs = [eg.reference for eg in examples]
+        student_docs = [eg.predicted for eg in examples]
         teacher_preds = teacher_pipe.predict(teacher_docs)
         teacher_pipe.set_annotations(teacher_docs, teacher_preds)
         return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)