diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg
index 8bd2c7065..1b360dd7f 100644
--- a/spacy/default_config.cfg
+++ b/spacy/default_config.cfg
@@ -85,7 +85,7 @@ frozen_components = []
 # Names of pipeline components that should set annotations during training
 annotating_components = []
 # Names of pipeline components that should get rehearsed during training
-rehearse_components = []
+rehearsal_components = []
 # Location in the config where the dev corpus is defined
 dev_corpus = "corpora.dev"
 # Location in the config where the train corpus is defined
diff --git a/spacy/language.py b/spacy/language.py
index 133c9aa93..170041bbe 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1183,7 +1183,7 @@ class Language:
         losses: Optional[Dict[str, float]] = None,
         component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
         exclude: Iterable[str] = SimpleFrozenList(),
-        rehearse_components: List[str] = [],
+        rehearsal_components: List[str] = [],
     ) -> Dict[str, float]:
         """Make a "rehearsal" update to the models in the pipeline, to prevent
         forgetting. Rehearsal updates run an initial copy of the model over some
@@ -1196,7 +1196,7 @@ class Language:
         component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
             components, keyed by component name.
         exclude (Iterable[str]): Names of components that shouldn't be updated.
-        rehearse_components (List[str]): Names of components that should be rehearsed
+        rehearsal_components (List[str]): Names of components that should be rehearsed
         RETURNS (dict): Results from the update.
 
         EXAMPLE:
@@ -1221,20 +1221,12 @@ class Language:
             if (
                 name in exclude
                 or not hasattr(proc, "rehearse")
-                or name not in rehearse_components
+                or name not in rehearsal_components
             ):
                 continue
-            proc.rehearse(  # type: ignore[attr-defined]
-                examples, sgd=None, losses=losses, **component_cfg.get(name, {})
+            proc.rehearse(
+                examples, sgd=sgd, losses=losses, **component_cfg.get(name, {})
             )
-            if isinstance(sgd, Optimizer):
-                if (
-                    name not in exclude
-                    and isinstance(proc, ty.TrainableComponent)
-                    and proc.is_trainable
-                    and proc.model not in (True, False, None)
-                ):
-                    proc.finish_update(sgd)
         return losses
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index b1698d11c..f236bc20e 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -228,7 +228,7 @@ class Tagger(TrainablePipe):
         loss_func = SequenceCategoricalCrossentropy()
         if losses is None:
             losses = {}
-        losses.setdefault(self.name, 0.0)
+        losses.setdefault(self.name + "_rehearse", 0.0)
         validate_examples(examples, "Tagger.rehearse")
         docs = [eg.predicted for eg in examples]
         if self._rehearsal_model is None:
@@ -243,7 +243,7 @@ class Tagger(TrainablePipe):
         bp_tag_scores(grads)
         if sgd is not None:
             self.finish_update(sgd)
-        losses[self.name] += loss
+        losses[self.name + "_rehearse"] += loss
         return losses
 
     def get_loss(self, examples, scores):
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 650a01949..8b9762c4a 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -276,7 +276,7 @@ class TextCategorizer(TrainablePipe):
         """
         if losses is None:
             losses = {}
-        losses.setdefault(self.name, 0.0)
+        losses.setdefault(self.name + "_rehearse", 0.0)
        if self._rehearsal_model is None:
             return losses
         validate_examples(examples, "TextCategorizer.rehearse")
@@ -292,7 +292,7 @@ class TextCategorizer(TrainablePipe):
         bp_scores(gradient)
         if sgd is not None:
             self.finish_update(sgd)
-        losses[self.name] += (gradient**2).sum()
+        losses[self.name + "_rehearse"] += (gradient**2).sum()
         return losses
 
     def _examples_to_truth(
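For context, here is how the split loss keys above surface to user code. A minimal, illustrative sketch (not part of the patch), assuming spaCy built from this branch and an installed `en_core_web_sm` pipeline:

```python
import spacy
from spacy.training import Example

nlp = spacy.load("en_core_web_sm")
# resume_training() also stores a frozen copy of each model for rehearsal
optimizer = nlp.resume_training()

doc = nlp.make_doc("I like eggs.")
example = Example.from_dict(doc, {"tags": ["PRP", "VBP", "NNS", "."]})

losses = {}
nlp.update([example], sgd=optimizer, losses=losses)
nlp.rehearse([example], sgd=optimizer, losses=losses,
             rehearsal_components=["tagger"])
print(losses)  # now holds e.g. both "tagger" and "tagger_rehearse" keys
```

Keeping the rehearsal loss under its own key makes the two objectives separately visible; previously both updates were folded into a single number per component.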
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 1327db2ce..ed58b41a5 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -444,7 +444,7 @@ cdef class Parser(TrainablePipe):
             multitask.rehearse(examples, losses=losses, sgd=sgd)
         if self._rehearsal_model is None:
             return None
-        losses.setdefault(self.name, 0.)
+        losses.setdefault(self.name + "_rehearse", 0.)
         validate_examples(examples, "Parser.rehearse")
         docs = [eg.predicted for eg in examples]
         states = self.moves.init_batch(docs)
@@ -475,7 +475,7 @@ cdef class Parser(TrainablePipe):
         backprop_tok2vec(docs)
         if sgd is not None:
             self.finish_update(sgd)
-        losses[self.name] += loss / n_scores
+        losses[self.name + "_rehearse"] += loss / n_scores
         del backprop
         del backprop_tok2vec
         model.clear_memory()
diff --git a/spacy/schemas.py b/spacy/schemas.py
index cab909e21..f707c7dcf 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -356,7 +356,7 @@ class ConfigSchemaTraining(BaseModel):
     logger: Logger = Field(..., title="The logger to track training progress")
     frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
     annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
-    rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
+    rehearsal_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
     before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
     before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
     # fmt: on
diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py
index bff84e924..65b03c30a 100644
--- a/spacy/tests/training/test_training.py
+++ b/spacy/tests/training/test_training.py
@@ -1144,7 +1144,7 @@ def test_training_before_update(doc):
         max_steps=100,
         exclude=[],
         annotating_components=[],
-        rehearse_components=[],
+        rehearsal_components=[],
         before_update=before_update,
     )
diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py
index 1d2184dd1..9453959bf 100644
--- a/spacy/training/initialize.py
+++ b/spacy/training/initialize.py
@@ -67,10 +67,10 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
         with nlp.select_pipes(enable=resume_components):
             logger.info(f"Resuming training for: {resume_components}")
             nlp.resume_training(sgd=optimizer)
-    # Components that shouldn't be updated during training
-    rehearse_components = T["rehearse_components"]
-    if rehearse_components:
-        logger.info(f"Rehearsing components: {rehearse_components}")
+    # Components that should be rehearsed during training
+    rehearsal_components = T["rehearsal_components"]
+    if rehearsal_components:
+        logger.info(f"Rehearsing components: {rehearsal_components}")
     # Make sure that listeners are defined before initializing further
     nlp._link_components()
     with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
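The resume path above is what makes rehearsal possible: a component only rehearses against the frozen copy of its initial weights that `Language.resume_training` stores as `_rehearsal_model`, and the `rehearse` implementations earlier in this patch return early when that copy is missing. A quick sketch, assuming `en_core_web_sm` is installed:

```python
import spacy

nlp = spacy.load("en_core_web_sm")
ner = nlp.get_pipe("ner")
print(ner._rehearsal_model)  # None: rehearse() would be a no-op here

optimizer = nlp.resume_training()  # copies the initial models
print(ner._rehearsal_model is not None)  # True: rehearsal can now run
```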
diff --git a/spacy/training/loop.py b/spacy/training/loop.py
index 9368384ae..233b5e642 100644
--- a/spacy/training/loop.py
+++ b/spacy/training/loop.py
@@ -78,7 +78,7 @@ def train(
     # Components that should set annotations on update
     annotating_components = T["annotating_components"]
     # Components that should be rehearsed after update
-    rehearse_components = T["rehearse_components"]
+    rehearsal_components = T["rehearsal_components"]
     # Create iterator, which yields out info after each optimization step.
     training_step_iterator = train_while_improving(
         nlp,
@@ -92,7 +92,7 @@
         eval_frequency=T["eval_frequency"],
         exclude=frozen_components,
         annotating_components=annotating_components,
-        rehearse_components=rehearse_components,
+        rehearsal_components=rehearsal_components,
         before_update=before_update,
     )
     clean_output_dir(output_path)
@@ -155,7 +155,7 @@ def train_while_improving(
     max_steps: int,
     exclude: List[str],
     annotating_components: List[str],
-    rehearse_components: List[str],
+    rehearsal_components: List[str],
     before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
 ):
     """Train until an evaluation stops improving. Works as a generator,
@@ -223,7 +223,7 @@
                 losses=losses,
                 sgd=None,
                 exclude=exclude,
-                rehearse_components=rehearse_components,
+                rehearsal_components=rehearsal_components,
             )
             # TODO: refactor this so we don't have to run it separately in here
             for name, proc in nlp.pipeline:
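Condensed, one optimization step now behaves roughly like the sketch below. This is a simplified restatement of `train_while_improving`, not the actual implementation; names mirror the diff, and batching details, memory handling and error checks are omitted:

```python
from typing import Dict, List

from spacy.language import Language
from spacy.training import Example
from thinc.api import Optimizer


def training_step(
    nlp: Language,
    subbatches: List[List[Example]],
    optimizer: Optimizer,
    exclude: List[str],
    annotating_components: List[str],
    rehearsal_components: List[str],
) -> Dict[str, float]:
    losses: Dict[str, float] = {}
    for subbatch in subbatches:
        # Accumulate gradients only (sgd=None) ...
        nlp.update(subbatch, losses=losses, sgd=None, exclude=exclude,
                   annotating_components=annotating_components)
        nlp.rehearse(subbatch, losses=losses, sgd=None, exclude=exclude,
                     rehearsal_components=rehearsal_components)
    # ... then apply them once, so update and rehearsal gradients share a
    # single optimizer step (the separate finish_update pass the TODO in
    # the diff refers to).
    for name, proc in nlp.pipeline:
        if name not in exclude and getattr(proc, "is_trainable", False):
            proc.finish_update(optimizer)
    return losses
```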
diff --git a/website/docs/api/data-formats.mdx b/website/docs/api/data-formats.mdx
index 41ef131d9..225c9c6b3 100644
--- a/website/docs/api/data-formats.mdx
+++ b/website/docs/api/data-formats.mdx
@@ -181,27 +181,27 @@ single corpus once and then divide it up into `train` and `dev` partitions.
 
 This section defines settings and controls for the training and evaluation
 process that are used when you run [`spacy train`](/api/cli#train).
 
-| Name | Description |
-| ---------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
-| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
-| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
-| `before_update` 3.5 | Optional callback that is invoked at the start of each training step with the `nlp` object and a `Dict` containing the following entries: `step`, `epoch`. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, Dict[str, Any]], None]]~~ |
-| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
-| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
-| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
-| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be initialized or updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
-| `annotating_components` 3.1 | Pipeline component names that should set annotations on the predicted docs during training. See [here](/usage/training#annotating-components) for details. Defaults to `[]`. ~~List[str]~~ |
-| `rehearse_components` 3.5.1 | Pipeline component names that should get rehearsed during training. See [here](/usage/training#rehearse-components) for details. Defaults to `[]`. ~~List[str]~~ |
-| `gpu_allocator` | Library for cupy to route GPU memory allocation to. Can be `"pytorch"` or `"tensorflow"`. Defaults to variable `${system.gpu_allocator}`. ~~str~~ |
-| `logger` | Callable that takes the `nlp` and stdout and stderr `IO` objects, sets up the logger, and returns two new callables to log a training step and to finalize the logger. Defaults to [`ConsoleLogger`](/api/top-level#ConsoleLogger). ~~Callable[[Language, IO, IO], [Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]]]~~ |
-| `max_epochs` | Maximum number of epochs to train for. `0` means an unlimited number of epochs. `-1` means that the train corpus should be streamed rather than loaded into memory with no shuffling within the training loop. Defaults to `0`. ~~int~~ |
-| `max_steps` | Maximum number of update steps to train for. `0` means an unlimited number of steps. Defaults to `20000`. ~~int~~ |
-| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
-| `patience` | How many steps to continue without improvement in evaluation score. `0` disables early stopping. Defaults to `1600`. ~~int~~ |
-| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
-| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
-| `train_corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
+| Name | Description |
+| ----------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
+| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
+| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
+| `before_update` 3.5 | Optional callback that is invoked at the start of each training step with the `nlp` object and a `Dict` containing the following entries: `step`, `epoch`. Can be used to make deferred changes to components. Defaults to `null`. ~~Optional[Callable[[Language, Dict[str, Any]], None]]~~ |
+| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
+| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
+| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
+| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be initialized or updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
+| `annotating_components` 3.1 | Pipeline component names that should set annotations on the predicted docs during training. See [here](/usage/training#annotating-components) for details. Defaults to `[]`. ~~List[str]~~ |
+| `rehearsal_components` 3.5.1 | Pipeline component names that should get rehearsed during training. See [here](/usage/training#rehearse-components) for details. Defaults to `[]`. ~~List[str]~~ |
+| `gpu_allocator` | Library for cupy to route GPU memory allocation to. Can be `"pytorch"` or `"tensorflow"`. Defaults to variable `${system.gpu_allocator}`. ~~str~~ |
+| `logger` | Callable that takes the `nlp` and stdout and stderr `IO` objects, sets up the logger, and returns two new callables to log a training step and to finalize the logger. Defaults to [`ConsoleLogger`](/api/top-level#ConsoleLogger). ~~Callable[[Language, IO, IO], [Tuple[Callable[[Dict[str, Any]], None], Callable[[], None]]]]~~ |
+| `max_epochs` | Maximum number of epochs to train for. `0` means an unlimited number of epochs. `-1` means that the train corpus should be streamed rather than loaded into memory with no shuffling within the training loop. Defaults to `0`. ~~int~~ |
+| `max_steps` | Maximum number of update steps to train for. `0` means an unlimited number of steps. Defaults to `20000`. ~~int~~ |
+| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
+| `patience` | How many steps to continue without improvement in evaluation score. `0` disables early stopping. Defaults to `1600`. ~~int~~ |
+| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
+| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
+| `train_corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
 
 ### pretraining {id="config-pretraining",tag="section,optional"}
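A quick way to sanity-check that the renamed setting documented above is picked up from a user config; this is a sketch, and `config.cfg` stands in for your own training config:

```python
from spacy.util import load_config

config = load_config("config.cfg", interpolate=True)
print(config["training"]["rehearsal_components"])  # e.g. ["ner"]
```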
diff --git a/website/docs/api/language.mdx b/website/docs/api/language.mdx
index 9831030d4..a744e89fc 100644
--- a/website/docs/api/language.mdx
+++ b/website/docs/api/language.mdx
@@ -346,7 +346,14 @@ and custom registered functions if needed. See the
 
 Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
 current model to make predictions similar to an initial model, to try to address
-the "catastrophic forgetting" problem. Please note that `Language.rehearse` needs to be used together with `Language.update`. This feature is experimental.
+the "catastrophic forgetting" problem.
+
+<Infobox variant="warning">
+
+Note that `Language.rehearse` needs to be used together with `Language.update`.
+This feature is experimental.
+
+</Infobox>
 
 > #### Example
 >
diff --git a/website/docs/api/pipe.mdx b/website/docs/api/pipe.mdx
index 744186e18..4f72ef0ff 100644
--- a/website/docs/api/pipe.mdx
+++ b/website/docs/api/pipe.mdx
@@ -244,7 +244,14 @@ predictions and gold-standard annotations, and update the component's model.
 
 Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
 current model to make predictions similar to an initial model, to try to address
-the "catastrophic forgetting" problem. Please note that `TrainablePipe.update` needs to be used together with `TrainablePipe.update`. This feature is experimental.
+the "catastrophic forgetting" problem.
+
+<Infobox variant="warning">
+
+Note that `TrainablePipe.rehearse` needs to be used together with
+`TrainablePipe.update`. This feature is experimental.
+
+</Infobox>
 
 > #### Example
 >
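A usage sketch of the pairing both infoboxes insist on: rehearsal accompanies regular updates rather than replacing them. Assumes a pre-trained pipeline and an existing list `train_examples` of `Example` objects:

```python
import spacy
from spacy.util import minibatch

nlp = spacy.load("en_core_web_sm")
optimizer = nlp.resume_training()

losses = {}
for batch in minibatch(train_examples, size=8):
    nlp.update(batch, sgd=optimizer, losses=losses)
    nlp.rehearse(batch, sgd=optimizer, losses=losses,
                 rehearsal_components=["ner"])
print(sorted(losses))  # includes both "ner" and "ner_rehearse"
```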
diff --git a/website/docs/usage/training.mdx b/website/docs/usage/training.mdx
index 0a05037c5..bae9ef326 100644
--- a/website/docs/usage/training.mdx
+++ b/website/docs/usage/training.mdx
@@ -577,7 +577,11 @@ now-updated model to the predicted docs.
 
 ### Using rehearsing to address catastrophic forgetting {id="rehearse-components", tag="experimental", version="3.5.1"}
 
-Perform “rehearsal” updates to pre-trained components. Rehearsal updates teach the current component to make predictions similar to an initial model, to try to address the “catastrophic forgetting” problem. This feature is experimental.
+When fine-tuning pre-trained components, we can perform an additional
+“rehearsal” update after every regular update to address the problem of
+[catastrophic forgetting](https://explosion.ai/blog/pseudo-rehearsal-catastrophic-forgetting).
+These updates teach the fine-tuned component to make predictions similar to its
+initial, pre-trained version. This feature is experimental.
 
 ```ini {title="config.cfg (excerpt)"}
 [nlp]
@@ -587,17 +591,9 @@ pipeline = ["sentencizer", "ner", "entity_linker"]
 source = "en_core_web_sm"
 
 [training]
-rehearse_components = ["ner"]
+rehearsal_components = ["ner"]
 ```
 
-<Infobox variant="warning">
-
-Be aware that the loss is calculated by the sum of both the `update` and `rehearse` function.
-If both the loss and accuracy of the component increases over time, it can be caused due to the trained component making more different predictions that the inital model,
-indicating `catastrophic forgetting`.
-
-</Infobox>
-
 ### Using registered functions {id="config-functions"}
 
 The training configuration defined in the config file doesn't have to only