Language.update: call pipe finish_update after all pipe updates

This does correct and fast updates if multiple components update the
same parameters.
This commit is contained in:
Daniël de Kok 2023-02-01 20:15:07 +01:00
parent 7753b88e0d
commit 6357d62ca2

View File

@ -1155,7 +1155,7 @@ class Language:
and isinstance(proc, ty.TrainableComponent)
and proc.is_trainable
):
proc.update(examples, sgd=sgd, losses=losses, **component_cfg[name])
proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
if name in annotates:
for doc, eg in zip(
_pipe(
@ -1168,6 +1168,14 @@ class Language:
examples,
):
eg.predicted = doc
for name, proc in self.pipeline:
if (
name not in exclude
and isinstance(proc, ty.TrainableComponent)
and proc.is_trainable
):
proc.finish_update(sgd)
return losses
def rehearse(