mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Add training.before_update
callback (#11739)
* Add `training.before_update` callback This callback can be used to implement training paradigms like gradual (un)freezing of components (e.g: the Transformer) after a certain number of training steps to mitigate catastrophic forgetting during fine-tuning. * Fix type annotation, default config value * Generalize arguments passed to the callback * Update schema * Pass `epoch` to callback, rename `current_step` to `step` * Add test * Simplify test * Replace config string with `spacy.blank` * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Cleanup imports Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
8271cfb4cd
commit
5ea14af32b
|
@ -90,6 +90,8 @@ dev_corpus = "corpora.dev"
|
||||||
train_corpus = "corpora.train"
|
train_corpus = "corpora.train"
|
||||||
# Optional callback before nlp object is saved to disk after training
|
# Optional callback before nlp object is saved to disk after training
|
||||||
before_to_disk = null
|
before_to_disk = null
|
||||||
|
# Optional callback that is invoked at the start of each training step
|
||||||
|
before_update = null
|
||||||
|
|
||||||
[training.logger]
|
[training.logger]
|
||||||
@loggers = "spacy.ConsoleLogger.v1"
|
@loggers = "spacy.ConsoleLogger.v1"
|
||||||
|
|
|
@ -329,6 +329,7 @@ class ConfigSchemaTraining(BaseModel):
|
||||||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||||
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
|
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
|
||||||
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
||||||
|
before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
|
@ -2,6 +2,7 @@ import random
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import pytest
|
import pytest
|
||||||
|
import spacy
|
||||||
import srsly
|
import srsly
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.tokens import Doc, DocBin
|
from spacy.tokens import Doc, DocBin
|
||||||
|
@ -11,9 +12,10 @@ from spacy.training import offsets_to_biluo_tags
|
||||||
from spacy.training.alignment_array import AlignmentArray
|
from spacy.training.alignment_array import AlignmentArray
|
||||||
from spacy.training.align import get_alignments
|
from spacy.training.align import get_alignments
|
||||||
from spacy.training.converters import json_to_docs
|
from spacy.training.converters import json_to_docs
|
||||||
|
from spacy.training.loop import train_while_improving
|
||||||
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
|
from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
|
||||||
from spacy.util import load_config_from_str
|
from spacy.util import load_config_from_str
|
||||||
from thinc.api import compounding
|
from thinc.api import compounding, Adam
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
@ -1112,3 +1114,39 @@ def test_retokenized_docs(doc):
|
||||||
retokenizer.merge(doc1[0:2])
|
retokenizer.merge(doc1[0:2])
|
||||||
retokenizer.merge(doc1[5:7])
|
retokenizer.merge(doc1[5:7])
|
||||||
assert example.get_aligned("ORTH", as_string=True) == expected2
|
assert example.get_aligned("ORTH", as_string=True) == expected2
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_before_update(doc):
|
||||||
|
def before_update(nlp, args):
|
||||||
|
assert args["step"] == 0
|
||||||
|
assert args["epoch"] == 1
|
||||||
|
|
||||||
|
# Raise an error here as the rest of the loop
|
||||||
|
# will not run to completion due to uninitialized
|
||||||
|
# models.
|
||||||
|
raise ValueError("ran_before_update")
|
||||||
|
|
||||||
|
def generate_batch():
|
||||||
|
yield 1, [Example(doc, doc)]
|
||||||
|
|
||||||
|
nlp = spacy.blank("en")
|
||||||
|
nlp.add_pipe("tagger")
|
||||||
|
optimizer = Adam()
|
||||||
|
generator = train_while_improving(
|
||||||
|
nlp,
|
||||||
|
optimizer,
|
||||||
|
generate_batch(),
|
||||||
|
lambda: None,
|
||||||
|
dropout=0.1,
|
||||||
|
eval_frequency=100,
|
||||||
|
accumulate_gradient=10,
|
||||||
|
patience=10,
|
||||||
|
max_steps=100,
|
||||||
|
exclude=[],
|
||||||
|
annotating_components=[],
|
||||||
|
before_update=before_update,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="ran_before_update"):
|
||||||
|
for _ in generator:
|
||||||
|
pass
|
||||||
|
|
|
@ -59,6 +59,7 @@ def train(
|
||||||
batcher = T["batcher"]
|
batcher = T["batcher"]
|
||||||
train_logger = T["logger"]
|
train_logger = T["logger"]
|
||||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||||
|
before_update = T["before_update"]
|
||||||
|
|
||||||
# Helper function to save checkpoints. This is a closure for convenience,
|
# Helper function to save checkpoints. This is a closure for convenience,
|
||||||
# to avoid passing in all the args all the time.
|
# to avoid passing in all the args all the time.
|
||||||
|
@ -89,6 +90,7 @@ def train(
|
||||||
eval_frequency=T["eval_frequency"],
|
eval_frequency=T["eval_frequency"],
|
||||||
exclude=frozen_components,
|
exclude=frozen_components,
|
||||||
annotating_components=annotating_components,
|
annotating_components=annotating_components,
|
||||||
|
before_update=before_update,
|
||||||
)
|
)
|
||||||
clean_output_dir(output_path)
|
clean_output_dir(output_path)
|
||||||
stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
|
stdout.write(msg.info(f"Pipeline: {nlp.pipe_names}") + "\n")
|
||||||
|
@ -150,6 +152,7 @@ def train_while_improving(
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
exclude: List[str],
|
exclude: List[str],
|
||||||
annotating_components: List[str],
|
annotating_components: List[str],
|
||||||
|
before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
|
||||||
):
|
):
|
||||||
"""Train until an evaluation stops improving. Works as a generator,
|
"""Train until an evaluation stops improving. Works as a generator,
|
||||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||||
|
@ -198,6 +201,9 @@ def train_while_improving(
|
||||||
words_seen = 0
|
words_seen = 0
|
||||||
start_time = timer()
|
start_time = timer()
|
||||||
for step, (epoch, batch) in enumerate(train_data):
|
for step, (epoch, batch) in enumerate(train_data):
|
||||||
|
if before_update:
|
||||||
|
before_update_args = {"step": step, "epoch": epoch}
|
||||||
|
before_update(nlp, before_update_args)
|
||||||
dropout = next(dropouts) # type: ignore
|
dropout = next(dropouts) # type: ignore
|
||||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||||
nlp.update(
|
nlp.update(
|
||||||
|
|
|
@ -186,6 +186,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
|
||||||
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
|
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
|
||||||
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
|
||||||
| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
|
| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
|
||||||
|
| `before_update` | Optional callback that is invoked at the start of each training step with the `nlp` object and a `Dict` containing the following entries: `step`, `epoch`. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, Dict[str, Any]], None]]~~ |
|
||||||
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
|
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
|
||||||
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
|
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
|
||||||
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
|
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user