name -> student_name

This commit is contained in:
Daniël de Kok 2023-01-30 10:10:22 +01:00
parent 81ccd66804
commit f6312eabba

View File

@ -1070,44 +1070,46 @@ class Language:
if component_cfg is None:
component_cfg = {}
pipe_kwargs = {}
for name, student_proc in self.pipeline:
component_cfg.setdefault(name, {})
pipe_kwargs[name] = deepcopy(component_cfg[name])
component_cfg[name].setdefault("drop", drop)
pipe_kwargs[name].setdefault("batch_size", self.batch_size)
for student_name, student_proc in self.pipeline:
component_cfg.setdefault(student_name, {})
pipe_kwargs[student_name] = deepcopy(component_cfg[student_name])
component_cfg[student_name].setdefault("drop", drop)
pipe_kwargs[student_name].setdefault("batch_size", self.batch_size)
teacher_pipes = dict(teacher.pipeline)
for name, student_proc in self.pipeline:
if name in annotates:
for student_name, student_proc in self.pipeline:
if student_name in annotates:
for doc, eg in zip(
_pipe(
(eg.predicted for eg in examples),
proc=student_proc,
name=name,
name=student_name,
default_error_handler=self.default_error_handler,
kwargs=pipe_kwargs[name],
kwargs=pipe_kwargs[student_name],
),
examples,
):
eg.predicted = doc
if (
name not in exclude
student_name not in exclude
and isinstance(student_proc, ty.DistillableComponent)
and student_proc.is_distillable
):
# A missing teacher pipe is not an error, some student pipes
# do not need a teacher, such as tok2vec layer losses.
teacher_pipe_name = (
student_to_teacher[name] if name in student_to_teacher else name
teacher_name = (
student_to_teacher[student_name]
if student_name in student_to_teacher
else student_name
)
teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
teacher_pipe = teacher_pipes.get(teacher_name, None)
student_proc.distill(
teacher_pipe,
examples,
sgd=sgd,
losses=losses,
**component_cfg[name],
**component_cfg[student_name],
)
return losses