Add rehearse_components to default config

This commit is contained in:
thomashacker 2023-01-19 16:21:14 +01:00
parent 6c81205978
commit 5050498661
3 changed files with 5 additions and 3 deletions

View File

@ -84,6 +84,8 @@ score_weights = {}
frozen_components = []
# Names of pipeline components that should set annotations during training
annotating_components = []
# Names of pipeline components that should get rehearsed during training
rehearse_components = []
# Location in the config where the dev corpus is defined
dev_corpus = "corpora.dev"
# Location in the config where the train corpus is defined

View File

@ -356,7 +356,7 @@ class ConfigSchemaTraining(BaseModel):
logger: Logger = Field(..., title="The logger to track training progress")
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
rehearse_components: Optional[List[str]] = Field(..., title="Pipeline components that should be rehearsed during training")
rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
# fmt: on

View File

@ -60,8 +60,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
optimizer = T["optimizer"]
# Components that shouldn't be updated during training
frozen_components = T["frozen_components"]
# Components that shouldn't be updated during training
rehearse_components = T["rehearse_components"]
# Sourced components that require resume_training
resume_components = [p for p in sourced if p not in frozen_components]
logger.info(f"Pipeline: {nlp.pipe_names}")
@ -69,6 +67,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
with nlp.select_pipes(enable=resume_components):
logger.info(f"Resuming training for: {resume_components}")
nlp.resume_training(sgd=optimizer)
# Components that should be rehearsed during training
rehearse_components = T["rehearse_components"]
if rehearse_components:
logger.info(f"Rehearsing components: {rehearse_components}")
# Make sure that listeners are defined before initializing further