mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add rehearse_components to default config
This commit is contained in:
parent
6c81205978
commit
5050498661
|
@ -84,6 +84,8 @@ score_weights = {}
|
||||||
frozen_components = []
|
frozen_components = []
|
||||||
# Names of pipeline components that should set annotations during training
|
# Names of pipeline components that should set annotations during training
|
||||||
annotating_components = []
|
annotating_components = []
|
||||||
|
# Names of pipeline components that should get rehearsed during training
|
||||||
|
rehearse_components = []
|
||||||
# Location in the config where the dev corpus is defined
|
# Location in the config where the dev corpus is defined
|
||||||
dev_corpus = "corpora.dev"
|
dev_corpus = "corpora.dev"
|
||||||
# Location in the config where the train corpus is defined
|
# Location in the config where the train corpus is defined
|
||||||
|
|
|
@ -356,7 +356,7 @@ class ConfigSchemaTraining(BaseModel):
|
||||||
logger: Logger = Field(..., title="The logger to track training progress")
|
logger: Logger = Field(..., title="The logger to track training progress")
|
||||||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||||
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
|
annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
|
||||||
rehearse_components: Optional[List[str]] = Field(..., title="Pipeline components that should be rehearsed during training")
|
rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
|
||||||
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
||||||
before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
|
before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -60,8 +60,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
optimizer = T["optimizer"]
|
optimizer = T["optimizer"]
|
||||||
# Components that shouldn't be updated during training
|
# Components that shouldn't be updated during training
|
||||||
frozen_components = T["frozen_components"]
|
frozen_components = T["frozen_components"]
|
||||||
# Components that shouldn't be updated during training
|
|
||||||
rehearse_components = T["rehearse_components"]
|
|
||||||
# Sourced components that require resume_training
|
# Sourced components that require resume_training
|
||||||
resume_components = [p for p in sourced if p not in frozen_components]
|
resume_components = [p for p in sourced if p not in frozen_components]
|
||||||
logger.info(f"Pipeline: {nlp.pipe_names}")
|
logger.info(f"Pipeline: {nlp.pipe_names}")
|
||||||
|
@ -69,6 +67,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
with nlp.select_pipes(enable=resume_components):
|
with nlp.select_pipes(enable=resume_components):
|
||||||
logger.info(f"Resuming training for: {resume_components}")
|
logger.info(f"Resuming training for: {resume_components}")
|
||||||
nlp.resume_training(sgd=optimizer)
|
nlp.resume_training(sgd=optimizer)
|
||||||
|
# Components that shouldn't be updated during training
|
||||||
|
rehearse_components = T["rehearse_components"]
|
||||||
if rehearse_components:
|
if rehearse_components:
|
||||||
logger.info(f"Rehearsing components: {rehearse_components}")
|
logger.info(f"Rehearsing components: {rehearse_components}")
|
||||||
# Make sure that listeners are defined before initializing further
|
# Make sure that listeners are defined before initializing further
|
||||||
|
|
Loading…
Reference in New Issue
Block a user