name -> student_name

This commit is contained in:
Daniël de Kok 2023-01-30 10:10:22 +01:00
parent 81ccd66804
commit f6312eabba

View File

@@ -1070,44 +1070,46 @@ class Language:
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
pipe_kwargs = {} pipe_kwargs = {}
for name, student_proc in self.pipeline: for student_name, student_proc in self.pipeline:
component_cfg.setdefault(name, {}) component_cfg.setdefault(student_name, {})
pipe_kwargs[name] = deepcopy(component_cfg[name]) pipe_kwargs[student_name] = deepcopy(component_cfg[student_name])
component_cfg[name].setdefault("drop", drop) component_cfg[student_name].setdefault("drop", drop)
pipe_kwargs[name].setdefault("batch_size", self.batch_size) pipe_kwargs[student_name].setdefault("batch_size", self.batch_size)
teacher_pipes = dict(teacher.pipeline) teacher_pipes = dict(teacher.pipeline)
for name, student_proc in self.pipeline: for student_name, student_proc in self.pipeline:
if name in annotates: if student_name in annotates:
for doc, eg in zip( for doc, eg in zip(
_pipe( _pipe(
(eg.predicted for eg in examples), (eg.predicted for eg in examples),
proc=student_proc, proc=student_proc,
name=name, name=student_name,
default_error_handler=self.default_error_handler, default_error_handler=self.default_error_handler,
kwargs=pipe_kwargs[name], kwargs=pipe_kwargs[student_name],
), ),
examples, examples,
): ):
eg.predicted = doc eg.predicted = doc
if ( if (
name not in exclude student_name not in exclude
and isinstance(student_proc, ty.DistillableComponent) and isinstance(student_proc, ty.DistillableComponent)
and student_proc.is_distillable and student_proc.is_distillable
): ):
# A missing teacher pipe is not an error, some student pipes # A missing teacher pipe is not an error, some student pipes
# do not need a teacher, such as tok2vec layer losses. # do not need a teacher, such as tok2vec layer losses.
teacher_pipe_name = ( teacher_name = (
student_to_teacher[name] if name in student_to_teacher else name student_to_teacher[student_name]
if student_name in student_to_teacher
else student_name
) )
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None) teacher_pipe = teacher_pipes.get(teacher_name, None)
student_proc.distill( student_proc.distill(
teacher_pipe, teacher_pipe,
examples, examples,
sgd=sgd, sgd=sgd,
losses=losses, losses=losses,
**component_cfg[name], **component_cfg[student_name],
) )
return losses return losses