diff --git a/spacy/language.py b/spacy/language.py
index 9cb8d938f..a0f2a990f 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1024,7 +1024,7 @@ class Language:
         examples: Iterable[Example],
         *,
         drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
+        sgd: Union[Optimizer, None, Literal[False]] = None,
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
@@ -1037,7 +1037,9 @@ class Language:
             (teacher) and predicted (student) docs must have the same number of
             tokens and the same orthography.
         drop (float): The dropout rate.
-        sgd (Optional[Optimizer]): An optimizer.
+        sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
+            be created via create_optimizer if 'None'. No optimizer will
+            be used when set to 'False'.
         losses (Optional(Dict[str, float])): Dictionary to update with the loss,
             keyed by component.
         component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
@@ -1107,11 +1109,23 @@ class Language:
                 student_proc.distill(
                     teacher_pipe,
                     examples,
-                    sgd=sgd,
+                    sgd=None,
                     losses=losses,
                     **component_cfg[student_name],
                 )
 
+        # Only finish the update after all component updates are done. Some
+        # components may share weights (such as tok2vec) and we only want
+        # to apply weight updates after all gradients are accumulated.
+        for student_name, student_proc in self.pipeline:
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and sgd not in (None, False)
+            ):
+                student_proc.finish_update(sgd)
+
         return losses
 
     def disable_pipes(self, *names) -> "DisabledPipes":
diff --git a/spacy/training/loop.py b/spacy/training/loop.py
index 8a0c9495f..7f67aa2cf 100644
--- a/spacy/training/loop.py
+++ b/spacy/training/loop.py
@@ -8,10 +8,12 @@ import random
 import sys
 import shutil
 
+
 from .example import Example
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..errors import Errors
 from ..tokens.doc import Doc
+from .. import ty
 from ..util import resolve_dot_names, registry, logger
 
 if TYPE_CHECKING:
@@ -340,20 +342,20 @@ def _distill_loop(
                 subbatch,
                 drop=dropout,
                 losses=losses,
-                sgd=None,
+                sgd=False,
                 exclude=exclude,
                 annotates=annotating_components,
                 student_to_teacher=student_to_teacher,
             )
         # TODO: refactor this so we don't have to run it separately in here
-        for name, proc in student.pipeline:
+        for student_name, student_proc in student.pipeline:
             if (
-                name not in exclude
-                and hasattr(proc, "is_trainable")
-                and proc.is_trainable
-                and proc.model not in (True, False, None)  # type: ignore[attr-defined]
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and student_proc.model not in (False, None)  # type: ignore[attr-defined]
             ):
-                proc.finish_update(optimizer)  # type: ignore[attr-defined]
+                student_proc.finish_update(optimizer)  # type: ignore[attr-defined]
         optimizer.step_schedules()
         if not (step % eval_frequency):
            if optimizer.averages:
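
For reference, below is a minimal sketch of the calling pattern this patch enables: the training loop passes sgd=False so that Language.distill only backpropagates and accumulates gradients, and the caller applies the weight updates afterwards by calling finish_update once per distillable component, as _distill_loop now does. The distill_batch helper, its argument names, and the subbatch layout are illustrative assumptions, not part of this change; only distill(sgd=False), ty.DistillableComponent, is_distillable, and finish_update come from the patch.

    from typing import Dict, Iterable, List

    from spacy import ty
    from spacy.language import Language
    from spacy.training import Example
    from thinc.api import Optimizer


    def distill_batch(
        teacher: Language,
        student: Language,
        subbatches: Iterable[List[Example]],
        optimizer: Optimizer,
    ) -> Dict[str, float]:
        """Accumulate gradients over all subbatches, then update weights once."""
        losses: Dict[str, float] = {}
        for subbatch in subbatches:
            # sgd=False: backprop and accumulate gradients, but do not step.
            student.distill(teacher, subbatch, sgd=False, losses=losses)
        # Apply the accumulated gradients in a single update per component,
        # so shared weights (such as tok2vec) are only stepped once.
        for name, proc in student.pipeline:
            if isinstance(proc, ty.DistillableComponent) and proc.is_distillable:
                proc.finish_update(optimizer)
        return losses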