Mirror of https://github.com/explosion/spaCy.git
student: do not request use of optimizer in student pipe

We finish up the updates once in the training loop instead. Also add the necessary logic to `Language.distill` to mirror `Language.update`.
Parent: b6544f50ec
Commit: add1a21657
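The net effect: inside `Language.distill`, each student pipe only accumulates gradients, and the weight update is applied once per component after all pipes have run. A minimal sketch of the calling pattern this enables, with hypothetical `teacher`, `student`, and `batches` placeholders that are not part of this commit:

```python
from thinc.api import Adam

# Hypothetical driver code; `teacher` and `student` (two Language objects)
# and `batches` (an iterable of Example batches) are assumed placeholders.
def distill_epoch(teacher, student, batches):
    optimizer = Adam(0.001)
    losses = {}
    for batch in batches:
        # sgd=False: accumulate gradients in the student pipes, but do not
        # touch any weights inside the call.
        student.distill(teacher, batch, losses=losses, sgd=False)
        # Apply the accumulated gradients exactly once per component, so
        # that shared weights (e.g. a shared tok2vec) are only updated
        # after every component has contributed its gradient.
        for name, proc in student.pipeline:
            if getattr(proc, "is_distillable", False):
                proc.finish_update(optimizer)
    return losses
```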
In `Language.distill` (spacy/language.py):

@@ -1024,7 +1024,7 @@ class Language:
         examples: Iterable[Example],
         *,
         drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
+        sgd: Union[Optimizer, None, Literal[False]] = None,
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
@@ -1037,7 +1037,9 @@ class Language:
         (teacher) and predicted (student) docs must have the same number of
         tokens and the same orthography.
         drop (float): The dropout rate.
-        sgd (Optional[Optimizer]): An optimizer.
+        sgd (Union[Optimizer, None, Literal[False]]): An optimizer. Will
+            be created via create_optimizer if 'None'. No optimizer will
+            be used when set to 'False'.
         losses (Optional(Dict[str, float])): Dictionary to update with the loss,
             keyed by component.
         component_cfg (Optional[Dict[str, Dict[str, Any]]]): Config parameters
@@ -1107,11 +1109,23 @@ class Language:
                 student_proc.distill(
                     teacher_pipe,
                     examples,
-                    sgd=sgd,
+                    sgd=None,
                     losses=losses,
                     **component_cfg[student_name],
                 )
 
+        # Only finish the update after all component updates are done. Some
+        # components may share weights (such as tok2vec) and we only want
+        # to apply weight updates after all gradients are accumulated.
+        for student_name, student_proc in self.pipeline:
+            if (
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and sgd not in (None, False)
+            ):
+                student_proc.finish_update(sgd)
+
         return losses
 
     def disable_pipes(self, *names) -> "DisabledPipes":
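For reference, the three accepted values of `sgd` under the new signature, as a hypothetical usage sketch (`teacher_nlp`, `student_nlp`, and `examples` are placeholders, not from the commit):

```python
from thinc.api import Adam

optimizer = Adam(0.001)
losses = {}

# 1. An Optimizer instance: weight updates are finished with this
#    optimizer once all components have accumulated their gradients.
student_nlp.distill(teacher_nlp, examples, sgd=optimizer, losses=losses)

# 2. None (the default): an optimizer is created via create_optimizer.
student_nlp.distill(teacher_nlp, examples, sgd=None, losses=losses)

# 3. False: gradients are accumulated, but no weights are updated; the
#    caller must invoke finish_update() on the pipes itself (this is
#    what the training loop below does).
student_nlp.distill(teacher_nlp, examples, sgd=False, losses=losses)
```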
In the training loop module (imports and `_distill_loop`):

@@ -8,10 +8,12 @@ import random
 import sys
 import shutil
 
 from .example import Example
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
 from ..errors import Errors
 from ..tokens.doc import Doc
+from .. import ty
 from ..util import resolve_dot_names, registry, logger
 
 if TYPE_CHECKING:
@@ -340,20 +342,20 @@ def _distill_loop(
                 subbatch,
                 drop=dropout,
                 losses=losses,
-                sgd=None,
+                sgd=False,
                 exclude=exclude,
                 annotates=annotating_components,
                 student_to_teacher=student_to_teacher,
             )
         # TODO: refactor this so we don't have to run it separately in here
-        for name, proc in student.pipeline:
+        for student_name, student_proc in student.pipeline:
             if (
-                name not in exclude
-                and hasattr(proc, "is_trainable")
-                and proc.is_trainable
-                and proc.model not in (True, False, None)  # type: ignore[attr-defined]
+                student_name not in exclude
+                and isinstance(student_proc, ty.DistillableComponent)
+                and student_proc.is_distillable
+                and student_proc.model not in (False, None)  # type: ignore[attr-defined]
             ):
-                proc.finish_update(optimizer)  # type: ignore[attr-defined]
+                student_proc.finish_update(optimizer)  # type: ignore[attr-defined]
         optimizer.step_schedules()
         if not (step % eval_frequency):
             if optimizer.averages:
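The old `hasattr`/`is_trainable` duck-typing check becomes an `isinstance` check against `ty.DistillableComponent`, which implies a runtime-checkable protocol. A minimal sketch of how such a check behaves; the member names mirror the diff, but the actual protocol in spaCy's `ty` module may differ:

```python
from typing import Iterable, Protocol, runtime_checkable

@runtime_checkable
class DistillableComponent(Protocol):
    """Sketch of a protocol in the spirit of ty.DistillableComponent."""

    is_distillable: bool

    def distill(self, teacher_pipe, examples: Iterable, **kwargs): ...

    def finish_update(self, sgd) -> None: ...

class MyPipe:
    is_distillable = True

    def distill(self, teacher_pipe, examples, **kwargs):
        pass

    def finish_update(self, sgd):
        pass

# isinstance() only verifies that all protocol members are present; it
# does not check signatures or attribute values.
assert isinstance(MyPipe(), DistillableComponent)
assert not isinstance(object(), DistillableComponent)
```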