Add training.before_update callback (#11739)

* Add `training.before_update` callback

This callback can be used to implement training paradigms like gradual (un)freezing of components (e.g. the transformer) after a certain number of training steps, to mitigate catastrophic forgetting during fine-tuning.
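As a rough sketch of such a paradigm (the factory name, the `unfreeze_step` threshold, and the `unfreeze_component` helper below are illustrative assumptions, not part of this change):

```python
from typing import Any, Callable, Dict
from spacy.language import Language


def unfreeze_component(nlp: Language, name: str) -> None:
    # Placeholder: how a component is unfrozen depends on the component,
    # e.g. flipping a flag that its update() implementation checks.
    ...


def create_gradual_unfreeze_callback(
    unfreeze_step: int,
) -> Callable[[Language, Dict[str, Any]], None]:
    def before_update(nlp: Language, args: Dict[str, Any]) -> None:
        # The training loop passes the current `step` and `epoch`.
        if args["step"] == unfreeze_step:
            unfreeze_component(nlp, "transformer")

    return before_update
```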

* Fix type annotation, default config value

* Generalize arguments passed to the callback

* Update schema

* Pass `epoch` to callback, rename `current_step` to `step`

* Add test

* Simplify test

* Replace config string with `spacy.blank`

* Apply suggestions from code review

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Cleanup imports

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Author: Madeesh Kannan
Date: 2022-11-23 17:54:58 +01:00, committed by GitHub
Parent: 8271cfb4cd
Commit: 5ea14af32b
5 changed files with 49 additions and 1 deletion

spacy/default_config.cfg

@@ -90,6 +90,8 @@ dev_corpus = "corpora.dev"
 train_corpus = "corpora.train"
 # Optional callback before nlp object is saved to disk after training
 before_to_disk = null
+# Optional callback that is invoked at the start of each training step
+before_update = null
 
 [training.logger]
 @loggers = "spacy.ConsoleLogger.v1"

spacy/schemas.py

@@ -329,6 +329,7 @@ class ConfigSchemaTraining(BaseModel):
     frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
     annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
     before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
+    before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
     # fmt: on
 
     class Config:

spacy/tests/training/test_training.py

@@ -2,6 +2,7 @@ import random
 import numpy
 import pytest
+import spacy
 import srsly
 from spacy.lang.en import English
 from spacy.tokens import Doc, DocBin
@@ -11,9 +12,10 @@ from spacy.training import offsets_to_biluo_tags
 from spacy.training.alignment_array import AlignmentArray
 from spacy.training.align import get_alignments
 from spacy.training.converters import json_to_docs
+from spacy.training.loop import train_while_improving
 from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
 from spacy.util import load_config_from_str
-from thinc.api import compounding
+from thinc.api import compounding, Adam
 from ..util import make_tempdir
@@ -1112,3 +1114,39 @@ def test_retokenized_docs(doc):
         retokenizer.merge(doc1[0:2])
         retokenizer.merge(doc1[5:7])
     assert example.get_aligned("ORTH", as_string=True) == expected2
+
+
+def test_training_before_update(doc):
+    def before_update(nlp, args):
+        assert args["step"] == 0
+        assert args["epoch"] == 1
+        # Raise an error here as the rest of the loop
+        # will not run to completion due to uninitialized
+        # models.
+        raise ValueError("ran_before_update")
+
+    def generate_batch():
+        yield 1, [Example(doc, doc)]
+
+    nlp = spacy.blank("en")
+    nlp.add_pipe("tagger")
+    optimizer = Adam()
+    generator = train_while_improving(
+        nlp,
+        optimizer,
+        generate_batch(),
+        lambda: None,
+        dropout=0.1,
+        eval_frequency=100,
+        accumulate_gradient=10,
+        patience=10,
+        max_steps=100,
+        exclude=[],
+        annotating_components=[],
+        before_update=before_update,
+    )
+    with pytest.raises(ValueError, match="ran_before_update"):
+        for _ in generator:
+            pass

spacy/training/loop.py

@@ -59,6 +59,7 @@ def train(
     batcher = T["batcher"]
     train_logger = T["logger"]
     before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
+    before_update = T["before_update"]
 
     # Helper function to save checkpoints. This is a closure for convenience,
     # to avoid passing in all the args all the time.
@@ -89,6 +90,7 @@ def train(
         eval_frequency=T["eval_frequency"],
         exclude=frozen_components,
         annotating_components=annotating_components,
+        before_update=before_update,
     )
     clean_output_dir(output_path)
     stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
@@ -150,6 +152,7 @@ def train_while_improving(
     max_steps: int,
     exclude: List[str],
     annotating_components: List[str],
+    before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
 ):
     """Train until an evaluation stops improving. Works as a generator,
     with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@@ -198,6 +201,9 @@ def train_while_improving(
     words_seen = 0
     start_time = timer()
     for step, (epoch, batch) in enumerate(train_data):
+        if before_update:
+            before_update_args = {"step": step, "epoch": epoch}
+            before_update(nlp, before_update_args)
         dropout = next(dropouts)  # type: ignore
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             nlp.update(

website/docs/api/data-formats.md

@@ -186,6 +186,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
 | `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
 | `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
 | `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
+| `before_update` | Optional callback that is invoked at the start of each training step with the `nlp` object and a `Dict` containing the following entries: `step`, `epoch`. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, Dict[str, Any]], None]]~~ |
 | `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
 | `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
 | `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
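To make the documented signature concrete, a minimal callback could look like the sketch below (the print statement is just an illustrative body):

```python
from typing import Any, Dict
from spacy.language import Language


def before_update(nlp: Language, args: Dict[str, Any]) -> None:
    # `args` contains the entries `step` and `epoch`.
    print(f"starting step {args['step']} (epoch {args['epoch']})")
```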