student: do not request use of optimizer in student pipe

We finish up the updates once in the training loop instead.

Also add the necessary logic to `Language.distill` to mirror
`Language.update`.
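
A minimal sketch of the calling pattern this enables (the `student`/`teacher` pipelines, `batches`, and `optimizer` names here are illustrative, not part of the diff):

    losses = {}
    for batch in batches:
        # sgd=False: components backprop and accumulate gradients,
        # but no weight updates are applied yet.
        student.distill(teacher, batch, sgd=False, losses=losses)
    # Apply the accumulated updates once, as the training loop below does.
    for name, proc in student.pipeline:
        if hasattr(proc, "finish_update"):
            proc.finish_update(optimizer)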
Daniël de Kok 2023-04-19 17:48:35 +02:00
parent b6544f50ec
commit add1a21657
2 changed files with 26 additions and 10 deletions

spacy/language.py

@@ -1024,7 +1024,7 @@ class Language:
         examples: Iterable[Example],
         *,
         drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
+        sgd: Union[Optimizer, None, Literal[False]] = None,
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
@@ -1037,7 +1037,9 @@ class Language:
             (teacher) and predicted (student) docs must have the same number of
             tokens and the same orthography.
         drop (float): The dropout rate.
-        sgd (Optional[Optimizer]): An optimizer.
+        sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
+            be created via create_optimizer if 'None'. No optimizer will
+            be used when set to 'False'.
         losses (Optional[Dict[str, float]]): Dictionary to update with the loss,
             keyed by component.
         component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
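
For reference, the three accepted values of `sgd` behave as follows (hypothetical calls, not part of the diff):

    student.distill(teacher, examples)             # None: an optimizer is created/reused internally
    student.distill(teacher, examples, sgd=opt)    # opt: use this provided thinc Optimizer
    student.distill(teacher, examples, sgd=False)  # accumulate only; caller finishes updates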
@@ -1107,11 +1109,23 @@ class Language:
                 student_proc.distill(
                     teacher_pipe,
                     examples,
-                    sgd=sgd,
+                    sgd=None,
                     losses=losses,
                     **component_cfg[student_name],
                 )

+        # Only finish the update after all component updates are done. Some
+        # components may share weights (such as tok2vec) and we only want
+        # to apply weight updates after all gradients are accumulated.
+        for student_name, student_proc in self.pipeline:
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and sgd not in (None, False)
+            ):
+                student_proc.finish_update(sgd)
+
         return losses

     def disable_pipes(self, *names) -> "DisabledPipes":
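
The comment in the new block is the crux: when two components share weights (such as a shared tok2vec), each component's `distill` contributes a gradient to the same parameters, and the update must be applied once on the sum. A toy illustration of the contract (plain Python, purely illustrative numbers):

    grad = 0.0
    for component_grad in (0.3, 0.7):  # e.g. tagger and parser both backprop
        grad += component_grad         # gradients accumulate on the shared weight
    weight = 1.0
    weight -= 0.1 * grad               # one optimizer step on the summed gradient
    # Stepping inside each component instead would update on 0.3 and then 0.7
    # separately; that matches only for plain SGD and diverges once the
    # optimizer has state (momentum, averages, schedules).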

spacy/training/loop.py

@@ -8,10 +8,12 @@ import random
 import sys
 import shutil

 from .example import Example
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..errors import Errors
 from ..tokens.doc import Doc
+from .. import ty
 from ..util import resolve_dot_names, registry, logger

 if TYPE_CHECKING:
@@ -340,20 +342,20 @@ def _distill_loop(
                 subbatch,
                 drop=dropout,
                 losses=losses,
-                sgd=None,
+                sgd=False,
                 exclude=exclude,
                 annotates=annotating_components,
                 student_to_teacher=student_to_teacher,
             )
         # TODO: refactor this so we don't have to run it separately in here
-        for name, proc in student.pipeline:
+        for student_name, student_proc in student.pipeline:
             if (
-                name not in exclude
-                and hasattr(proc, "is_trainable")
-                and proc.is_trainable
-                and proc.model not in (True, False, None)  # type: ignore[attr-defined]
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and student_proc.model not in (False, None)  # type: ignore[attr-defined]
             ):
-                proc.finish_update(optimizer)  # type: ignore[attr-defined]
+                student_proc.finish_update(optimizer)  # type: ignore[attr-defined]
         optimizer.step_schedules()
         if not (step % eval_frequency):
             if optimizer.averages:
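
Condensed, the per-step pattern that `_distill_loop` now follows looks like this (a sketch using the diff's own names; batching and evaluation details from the surrounding function are omitted):

    for step, batch in enumerate(batches):
        for subbatch in subdivide_batch(batch, accumulate_gradient):
            # Gradients accumulate across subbatches; sgd=False defers updates.
            student.distill(teacher, subbatch, drop=dropout, losses=losses, sgd=False)
        # One finish_update per distillable component, then advance schedules.
        for student_name, student_proc in student.pipeline:
            if isinstance(student_proc, ty.DistillableComponent) and student_proc.is_distillable:
                student_proc.finish_update(optimizer)
        optimizer.step_schedules()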