diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg index 694fb732f..8bd2c7065 100644 --- a/spacy/default_config.cfg +++ b/spacy/default_config.cfg @@ -84,6 +84,8 @@ score_weights = {} frozen_components = [] # Names of pipeline components that should set annotations during training annotating_components = [] +# Names of pipeline components that should get rehearsed during training +rehearse_components = [] # Location in the config where the dev corpus is defined dev_corpus = "corpora.dev" # Location in the config where the train corpus is defined diff --git a/spacy/schemas.py b/spacy/schemas.py index 9bdff3030..cab909e21 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -356,7 +356,7 @@ class ConfigSchemaTraining(BaseModel): logger: Logger = Field(..., title="The logger to track training progress") frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training") annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training") - rehearse_components: Optional[List[str]] = Field(..., title="Pipeline components that should be rehearsed during training") + rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training") before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk") before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step") # fmt: on diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 2315c1140..1d2184dd1 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -60,8 +60,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": optimizer = T["optimizer"] # Components that shouldn't be updated during training frozen_components = T["frozen_components"] - # Components that shouldn't be updated during training - rehearse_components = T["rehearse_components"] # Sourced components that require resume_training resume_components = [p for p in sourced if p not in frozen_components] logger.info(f"Pipeline: {nlp.pipe_names}") @@ -69,6 +67,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": with nlp.select_pipes(enable=resume_components): logger.info(f"Resuming training for: {resume_components}") nlp.resume_training(sgd=optimizer) + # Components that shouldn't be updated during training + rehearse_components = T["rehearse_components"] if rehearse_components: logger.info(f"Rehearsing components: {rehearse_components}") # Make sure that listeners are defined before initializing further