student: do not request use of optimizer in student pipe

We now finish up the updates once in the training loop instead. Also add the necessary logic to `Language.distill` to mirror `Language.update`.
This commit is contained in:
parent b6544f50ec
commit add1a21657
spacy/language.py

@@ -1024,7 +1024,7 @@ class Language:
         examples: Iterable[Example],
         *,
         drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
+        sgd: Union[Optimizer, None, Literal[False]] = None,
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
@@ -1037,7 +1037,9 @@ class Language:
         (teacher) and predicted (student) docs must have the same number of
         tokens and the same orthography.
         drop (float): The dropout rate.
-        sgd (Optional[Optimizer]): An optimizer.
+        sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
+            be created via create_optimizer if 'None'. No optimizer will
+            be used when set to 'False'.
         losses (Optional[Dict[str, float]]): Dictionary to update with the loss,
             keyed by component.
         component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
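
The new docstring distinguishes three modes for `sgd`. A minimal usage sketch, assuming a distillable `student` pipeline, a `teacher` pipeline, `examples`, and a Thinc `optimizer` created elsewhere:

    # sgd=None (default): Language.distill creates (or reuses) an optimizer.
    losses = student.distill(teacher, examples)

    # Pass a Thinc optimizer to use it directly.
    losses = student.distill(teacher, examples, sgd=optimizer)

    # sgd=False: gradients are only accumulated; the caller applies them
    # later with finish_update, as the training loop change below does.
    losses = student.distill(teacher, examples, sgd=False)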
@@ -1107,11 +1109,23 @@ class Language:
                 student_proc.distill(
                     teacher_pipe,
                     examples,
-                    sgd=sgd,
+                    sgd=None,
                     losses=losses,
                     **component_cfg[student_name],
                 )
 
+        # Only finish the update after all component updates are done. Some
+        # components may share weights (such as tok2vec) and we only want
+        # to apply weight updates after all gradients are accumulated.
+        for student_name, student_proc in self.pipeline:
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and sgd not in (None, False)
+            ):
+                student_proc.finish_update(sgd)
+
         return losses
 
     def disable_pipes(self, *names) -> "DisabledPipes":
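
Why defer `finish_update`: components can share a model (listeners on `tok2vec`), and each component's backprop accumulates gradients on the shared weights. A minimal Thinc sketch of this accumulate-then-apply pattern, using plain `thinc.api` objects rather than spaCy components:

    from thinc.api import SGD, Linear
    import numpy

    model = Linear(2, 2)  # stands in for a shared tok2vec
    model.initialize(X=numpy.zeros((1, 2), dtype="f"))
    optimizer = SGD(0.1)

    X = numpy.ones((1, 2), dtype="f")
    for _ in range(2):  # two "components" backprop into the shared model
        Y, backprop = model.begin_update(X)
        backprop(numpy.ones_like(Y))  # gradients accumulate on the model

    # One optimizer application after all gradients are in, which is what
    # the loop added above does via student_proc.finish_update(sgd).
    model.finish_update(optimizer)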
|
@ -8,10 +8,12 @@ import random
|
|||
import sys
|
||||
import shutil
|
||||
|
||||
|
||||
from .example import Example
|
||||
from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
|
||||
from ..errors import Errors
|
||||
from ..tokens.doc import Doc
|
||||
from .. import ty
|
||||
from ..util import resolve_dot_names, registry, logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@@ -340,20 +342,20 @@ def _distill_loop(
                 subbatch,
                 drop=dropout,
                 losses=losses,
-                sgd=None,
+                sgd=False,
                 exclude=exclude,
                 annotates=annotating_components,
                 student_to_teacher=student_to_teacher,
             )
         # TODO: refactor this so we don't have to run it separately in here
-        for name, proc in student.pipeline:
+        for student_name, student_proc in student.pipeline:
             if (
-                name not in exclude
-                and hasattr(proc, "is_trainable")
-                and proc.is_trainable
-                and proc.model not in (True, False, None)  # type: ignore[attr-defined]
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and student_proc.model not in (False, None)  # type: ignore[attr-defined]
             ):
-                proc.finish_update(optimizer)  # type: ignore[attr-defined]
+                student_proc.finish_update(optimizer)  # type: ignore[attr-defined]
         optimizer.step_schedules()
         if not (step % eval_frequency):
            if optimizer.averages:
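
Putting it together, one step of the loop now distills with `sgd=False` and applies the updates itself. A hedged sketch, assuming `student`, `teacher`, `batch`, and `optimizer` exist, with `getattr` standing in for the `ty.DistillableComponent` check:

    losses = student.distill(teacher, batch, sgd=False, losses={})
    for name, proc in student.pipeline:
        if getattr(proc, "is_distillable", False):
            proc.finish_update(optimizer)  # apply accumulated gradients
    optimizer.step_schedules()  # advance learning-rate schedules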