mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add rehearse_components to default config
This commit is contained in:
parent
6c81205978
commit
5050498661
|
@ -84,6 +84,8 @@ score_weights = {}
|
|||
frozen_components = []
|
||||
# Names of pipeline components that should set annotations during training
|
||||
annotating_components = []
|
||||
# Names of pipeline components that should get rehearsed during training
|
||||
rehearse_components = []
|
||||
# Location in the config where the dev corpus is defined
|
||||
dev_corpus = "corpora.dev"
|
||||
# Location in the config where the train corpus is defined
|
||||
|
|
|
@ -356,7 +356,7 @@ class ConfigSchemaTraining(BaseModel):
|
|||
logger: Logger = Field(..., title="The logger to track training progress")
|
||||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
|
||||
rehearse_components: Optional[List[str]] = Field(..., title="Pipeline components that should be rehearsed during training")
|
||||
rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
|
||||
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
||||
before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
|
||||
# fmt: on
|
||||
|
|
|
@ -60,8 +60,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
|||
optimizer = T["optimizer"]
|
||||
# Components that shouldn't be updated during training
|
||||
frozen_components = T["frozen_components"]
|
||||
# Components that shouldn't be updated during training
|
||||
rehearse_components = T["rehearse_components"]
|
||||
# Sourced components that require resume_training
|
||||
resume_components = [p for p in sourced if p not in frozen_components]
|
||||
logger.info(f"Pipeline: {nlp.pipe_names}")
|
||||
|
@ -69,6 +67,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
|||
with nlp.select_pipes(enable=resume_components):
|
||||
logger.info(f"Resuming training for: {resume_components}")
|
||||
nlp.resume_training(sgd=optimizer)
|
||||
# Components that shouldn't be updated during training
|
||||
rehearse_components = T["rehearse_components"]
|
||||
if rehearse_components:
|
||||
logger.info(f"Rehearsing components: {rehearse_components}")
|
||||
# Make sure that listeners are defined before initializing further
|
||||
|
|
Loading…
Reference in New Issue
Block a user