diff --git a/spacy/cli/train.py b/spacy/cli/train.py index eabc82be0..6d61c2425 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -97,6 +97,7 @@ def train( dev_corpus = dot_to_object(config, T_cfg["dev_corpus"]) batcher = T_cfg["batcher"] train_logger = T_cfg["logger"] + before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"]) # Components that shouldn't be updated during training frozen_components = T_cfg["frozen_components"] # Sourced components that require resume_training @@ -167,6 +168,7 @@ def train( with nlp.select_pipes(disable=frozen_components): update_meta(T_cfg, nlp, info) with nlp.use_params(optimizer.averages): + nlp = before_to_disk(nlp) nlp.to_disk(output_path / "model-best") progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) progress.set_description(f"Epoch {info['epoch']}") @@ -179,6 +181,7 @@ def train( f"Aborting and saving the final best model. " f"Encountered exception: {str(e)}" ) + nlp = before_to_disk(nlp) nlp.to_disk(output_path / "model-final") raise e finally: @@ -233,6 +236,21 @@ def create_evaluation_callback( return evaluate +def create_before_to_disk_callback( + callback: Optional[Callable[[Language], Language]] +) -> Callable[[Language], Language]: + def before_to_disk(nlp: Language) -> Language: + if not callback: + return nlp + modified_nlp = callback(nlp) + if not isinstance(modified_nlp, Language): + err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp)) + raise ValueError(err) + return modified_nlp + + return before_to_disk + + def train_while_improving( nlp: Language, optimizer: Optimizer, diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg index 5cd97a0eb..6f8c0aa00 100644 --- a/spacy/default_config.cfg +++ b/spacy/default_config.cfg @@ -72,6 +72,8 @@ frozen_components = [] dev_corpus = "corpora.dev" # Location in the config where the train corpus is defined train_corpus = "corpora.train" +# Optional callback before nlp object is saved to disk after training +before_to_disk = null [training.logger] @loggers = "spacy.ConsoleLogger.v1" diff --git a/spacy/errors.py b/spacy/errors.py index dce5cf51c..d67f01a1d 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -480,6 +480,9 @@ class Errors: E201 = ("Span index out of range.") # TODO: fix numbering after merging develop into master + E914 = ("Executing {name} callback failed. Expected the function to " + "returnthe nlp object but got: {value}. Maybe you forgot to return " + "the modified object in your function?") E915 = ("Can't use score '{name}' to calculate final weighted score. Expected " "float or int but got: {score_type}. To exclude the score from the " "final score, set its weight to null in the [training.score_weights] " diff --git a/spacy/schemas.py b/spacy/schemas.py index e34841008..6a9a82d06 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -217,6 +217,7 @@ class ConfigSchemaTraining(BaseModel): optimizer: Optimizer = Field(..., title="The optimizer to use") logger: Logger = Field(..., title="The logger to track training progress") frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training") + before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk") # fmt: on class Config: