mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-24 19:11:58 +03:00
Tok2Vec
: Add distill
method
This commit is contained in:
parent
5e297aa20e
commit
c96786152b
|
@ -219,6 +219,20 @@ class Tok2Vec(TrainablePipe):
|
|||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
||||
def distill(
|
||||
self,
|
||||
teacher_pipe: "TrainablePipe",
|
||||
teacher_docs: Iterable["Doc"],
|
||||
student_docs: Iterable["Doc"],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
) -> Dict[str, float]:
|
||||
teacher_pipe.set_annotations(teacher_docs, teacher_pipe.predict(teacher_docs))
|
||||
examples = [Example(doc, doc) for doc in student_docs]
|
||||
return self.update(examples, drop=drop, sgd=sgd, losses=losses)
|
||||
|
||||
|
||||
class Tok2VecListener(Model):
|
||||
"""A layer that gets fed its answers from an upstream connection,
|
||||
|
|
Loading…
Reference in New Issue
Block a user