mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add rehearse_components to default config
This commit is contained in:
		
							parent
							
								
									6c81205978
								
							
						
					
					
						commit
						5050498661
					
				| 
						 | 
					@ -84,6 +84,8 @@ score_weights = {}
 | 
				
			||||||
frozen_components = []
 | 
					frozen_components = []
 | 
				
			||||||
# Names of pipeline components that should set annotations during training
 | 
					# Names of pipeline components that should set annotations during training
 | 
				
			||||||
annotating_components = []
 | 
					annotating_components = []
 | 
				
			||||||
 | 
					# Names of pipeline components that should get rehearsed during training
 | 
				
			||||||
 | 
					rehearse_components = []
 | 
				
			||||||
# Location in the config where the dev corpus is defined
 | 
					# Location in the config where the dev corpus is defined
 | 
				
			||||||
dev_corpus = "corpora.dev"
 | 
					dev_corpus = "corpora.dev"
 | 
				
			||||||
# Location in the config where the train corpus is defined
 | 
					# Location in the config where the train corpus is defined
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -356,7 +356,7 @@ class ConfigSchemaTraining(BaseModel):
 | 
				
			||||||
    logger: Logger = Field(..., title="The logger to track training progress")
 | 
					    logger: Logger = Field(..., title="The logger to track training progress")
 | 
				
			||||||
    frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
 | 
					    frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
 | 
				
			||||||
    annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
 | 
					    annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
 | 
				
			||||||
    rehearse_components: Optional[List[str]] = Field(..., title="Pipeline components that should be rehearsed during training")
 | 
					    rehearse_components: List[str] = Field(..., title="Pipeline components that should be rehearsed during training")
 | 
				
			||||||
    before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
 | 
					    before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
 | 
				
			||||||
    before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
 | 
					    before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
 | 
				
			||||||
    # fmt: on
 | 
					    # fmt: on
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,8 +60,6 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
 | 
				
			||||||
    optimizer = T["optimizer"]
 | 
					    optimizer = T["optimizer"]
 | 
				
			||||||
    # Components that shouldn't be updated during training
 | 
					    # Components that shouldn't be updated during training
 | 
				
			||||||
    frozen_components = T["frozen_components"]
 | 
					    frozen_components = T["frozen_components"]
 | 
				
			||||||
    # Components that should be rehearsed during training
 | 
					 | 
				
			||||||
    rehearse_components = T["rehearse_components"]
 | 
					 | 
				
			||||||
    # Sourced components that require resume_training
 | 
					    # Sourced components that require resume_training
 | 
				
			||||||
    resume_components = [p for p in sourced if p not in frozen_components]
 | 
					    resume_components = [p for p in sourced if p not in frozen_components]
 | 
				
			||||||
    logger.info(f"Pipeline: {nlp.pipe_names}")
 | 
					    logger.info(f"Pipeline: {nlp.pipe_names}")
 | 
				
			||||||
| 
						 | 
					@ -69,6 +67,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
 | 
				
			||||||
        with nlp.select_pipes(enable=resume_components):
 | 
					        with nlp.select_pipes(enable=resume_components):
 | 
				
			||||||
            logger.info(f"Resuming training for: {resume_components}")
 | 
					            logger.info(f"Resuming training for: {resume_components}")
 | 
				
			||||||
            nlp.resume_training(sgd=optimizer)
 | 
					            nlp.resume_training(sgd=optimizer)
 | 
				
			||||||
 | 
					    # Components that should be rehearsed during training
 | 
				
			||||||
 | 
					    rehearse_components = T["rehearse_components"]
 | 
				
			||||||
    if rehearse_components:
 | 
					    if rehearse_components:
 | 
				
			||||||
        logger.info(f"Rehearsing components: {rehearse_components}")
 | 
					        logger.info(f"Rehearsing components: {rehearse_components}")
 | 
				
			||||||
    # Make sure that listeners are defined before initializing further
 | 
					    # Make sure that listeners are defined before initializing further
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user