mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Add training.before_update
callback (#11739)
* Add `training.before_update` callback This callback can be used to implement training paradigms like gradual (un)freezing of components (e.g: the Transformer) after a certain number of training steps to mitigate catastrophic forgetting during fine-tuning. * Fix type annotation, default config value * Generalize arguments passed to the callback * Update schema * Pass `epoch` to callback, rename `current_step` to `step` * Add test * Simplify test * Replace config string with `spacy.blank` * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Cleanup imports Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
8271cfb4cd
commit
5ea14af32b
|
@ -90,6 +90,8 @@ dev_corpus = "corpora.dev"
|
|||
train_corpus = "corpora.train"
|
||||
# Optional callback before nlp object is saved to disk after training
|
||||
before_to_disk = null
|
||||
# Optional callback that is invoked at the start of each training step
|
||||
before_update = null
|
||||
|
||||
[training.logger]
|
||||
@loggers = "spacy.ConsoleLogger.v1"
|
||||
|
|
|
@ -329,6 +329,7 @@ class ConfigSchemaTraining(BaseModel):
|
|||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
|
||||
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
||||
before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
|
|
|
@ -2,6 +2,7 @@ import random
|
|||
|
||||
import numpy
|
||||
import pytest
|
||||
import spacy
|
||||
import srsly
|
||||
from spacy.lang.en import English
|
||||
from spacy.tokens import Doc, DocBin
|
||||
|
@ -11,9 +12,10 @@ from spacy.training import offsets_to_biluo_tags
|
|||
from spacy.training.alignment_array import AlignmentArray
|
||||
from spacy.training.align import get_alignments
|
||||
from spacy.training.converters import json_to_docs
|
||||
from spacy.training.loop import train_while_improving
|
||||
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
|
||||
from spacy.util import load_config_from_str
|
||||
from thinc.api import compounding
|
||||
from thinc.api import compounding, Adam
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -1112,3 +1114,39 @@ def test_retokenized_docs(doc):
|
|||
retokenizer.merge(doc1[0:2])
|
||||
retokenizer.merge(doc1[5:7])
|
||||
assert example.get_aligned("ORTH", as_string=True) == expected2
|
||||
|
||||
|
||||
def test_training_before_update(doc):
|
||||
def before_update(nlp, args):
|
||||
assert args["step"] == 0
|
||||
assert args["epoch"] == 1
|
||||
|
||||
# Raise an error here as the rest of the loop
|
||||
# will not run to completion due to uninitialized
|
||||
# models.
|
||||
raise ValueError("ran_before_update")
|
||||
|
||||
def generate_batch():
|
||||
yield 1, [Example(doc, doc)]
|
||||
|
||||
nlp = spacy.blank("en")
|
||||
nlp.add_pipe("tagger")
|
||||
optimizer = Adam()
|
||||
generator = train_while_improving(
|
||||
nlp,
|
||||
optimizer,
|
||||
generate_batch(),
|
||||
lambda: None,
|
||||
dropout=0.1,
|
||||
eval_frequency=100,
|
||||
accumulate_gradient=10,
|
||||
patience=10,
|
||||
max_steps=100,
|
||||
exclude=[],
|
||||
annotating_components=[],
|
||||
before_update=before_update,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="ran_before_update"):
|
||||
for _ in generator:
|
||||
pass
|
||||
|
|
|
@ -59,6 +59,7 @@ def train(
|
|||
batcher = T["batcher"]
|
||||
train_logger = T["logger"]
|
||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||
before_update = T["before_update"]
|
||||
|
||||
# Helper function to save checkpoints. This is a closure for convenience,
|
||||
# to avoid passing in all the args all the time.
|
||||
|
@ -89,6 +90,7 @@ def train(
|
|||
eval_frequency=T["eval_frequency"],
|
||||
exclude=frozen_components,
|
||||
annotating_components=annotating_components,
|
||||
before_update=before_update,
|
||||
)
|
||||
clean_output_dir(output_path)
|
||||
stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
|
||||
|
@ -150,6 +152,7 @@ def train_while_improving(
|
|||
max_steps: int,
|
||||
exclude: List[str],
|
||||
annotating_components: List[str],
|
||||
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
|
||||
):
|
||||
"""Train until an evaluation stops improving. Works as a generator,
|
||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||
|
@ -198,6 +201,9 @@ def train_while_improving(
|
|||
words_seen = 0
|
||||
start_time = timer()
|
||||
for step, (epoch, batch) in enumerate(train_data):
|
||||
if before_update:
|
||||
before_update_args = {"step": step, "epoch": epoch}
|
||||
before_update(nlp, before_update_args)
|
||||
dropout = next(dropouts) # type: ignore
|
||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||
nlp.update(
|
||||
|
|
|
@ -186,6 +186,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
|
|||
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
|
||||
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
||||
| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
|
||||
| `before_update` | Optional callback that is invoked at the start of each training step with the `nlp` object and a `Dict` containing the following entries: `step`, `epoch`. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, Dict[str, Any]], None]]~~ |
|
||||
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
|
||||
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
|
||||
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
|
||||
|
|
Loading…
Reference in New Issue
Block a user