Update distill signature to accept Examples instead of separate teacher and student docs

shademe 2023-01-16 17:08:25 +01:00
parent 38929c2ca9
commit 553181b3ca


@@ -192,14 +192,37 @@ class Tok2Vec(TrainablePipe):
     def distill(
         self,
-        teacher_pipe: "TrainablePipe",
-        teacher_docs: Iterable["Doc"],
-        student_docs: Iterable["Doc"],
+        teacher_pipe: Optional["TrainablePipe"],
+        examples: Iterable["Example"],
         *,
         drop: float = 0.0,
         sgd: Optional[Optimizer] = None,
         losses: Optional[Dict[str, float]] = None,
     ) -> Dict[str, float]:
+        """Train a pipe (the student) on the predictions of another pipe
+        (the teacher). The student is typically trained on the probability
+        distribution of the teacher, but details may differ per pipe.
+
+        teacher_pipe (Optional[TrainablePipe]): The teacher pipe to learn
+            from.
+        examples (Iterable[Example]): Distillation examples. The reference
+            and predicted docs must have the same number of tokens and the
+            same orthography.
+        drop (float): dropout rate.
+        sgd (Optional[Optimizer]): An optimizer. Will be created via
+            create_optimizer if not set.
+        losses (Optional[Dict[str, float]]): Optional record of loss during
+            distillation.
+        RETURNS: The updated losses dictionary.
+
+        DOCS: https://spacy.io/api/tok2vec#distill
+        """
+        # By default we require a teacher pipe, but there are downstream
+        # implementations that don't require a pipe.
+        if teacher_pipe is None:
+            raise ValueError(Errors.E4002.format(name=self.name))
+        teacher_docs = [eg.reference for eg in examples]
+        student_docs = [eg.predicted for eg in examples]
         teacher_preds = teacher_pipe.predict(teacher_docs)
         teacher_pipe.set_annotations(teacher_docs, teacher_preds)
         return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)
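
For context, a minimal usage sketch of the new Example-based signature. This is not part of the commit: it assumes a spaCy build that includes the distill API, and the model name "en_core_web_sm" plus the overall setup are illustrative assumptions.

# Hypothetical sketch, not part of this commit: distilling a student
# tok2vec from a teacher pipeline via the new Example-based signature.
import spacy
from spacy.training import Example

teacher = spacy.load("en_core_web_sm")  # assumed teacher pipeline with a "tok2vec" pipe
student = spacy.blank("en")
student_tok2vec = student.add_pipe("tok2vec")
student.initialize()

texts = ["Distillation aligns student and teacher representations."]
# Example(predicted, reference): the student doc goes in "predicted", the
# teacher doc in "reference"; both must share tokenization and orthography.
examples = [Example(student.make_doc(t), teacher.make_doc(t)) for t in texts]

losses = student_tok2vec.distill(teacher.get_pipe("tok2vec"), examples, losses={})
print(losses)

Pairing both docs inside one Example lets the pipe enforce the same-tokens, same-orthography invariant in a single place, rather than trusting callers to keep two parallel iterables aligned.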