Tok2Vec: Add distill method

This commit is contained in:
shademe 2022-12-14 12:47:20 +01:00
parent 5e297aa20e
commit c96786152b
No known key found for this signature in database
GPG Key ID: 6FCA9FC635B2A402

View File

@ -219,6 +219,20 @@ class Tok2Vec(TrainablePipe):
def add_label(self, label):
raise NotImplementedError
def distill(
self,
teacher_pipe: "TrainablePipe",
teacher_docs: Iterable["Doc"],
student_docs: Iterable["Doc"],
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
teacher_pipe.set_annotations(teacher_docs, teacher_pipe.predict(teacher_docs))
examples = [Example(doc, doc) for doc in student_docs]
return self.update(examples, drop=drop, sgd=sgd, losses=losses)
class Tok2VecListener(Model):
"""A layer that gets fed its answers from an upstream connection,