Add [training.before_to_disk] callback

This commit is contained in:
Ines Montani 2020-09-24 12:40:25 +02:00
parent d7ab6a2ffe
commit be56c0994b
4 changed files with 24 additions and 0 deletions

View File

@ -97,6 +97,7 @@ def train(
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"]) dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
batcher = T_cfg["batcher"] batcher = T_cfg["batcher"]
train_logger = T_cfg["logger"] train_logger = T_cfg["logger"]
before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
# Components that shouldn't be updated during training # Components that shouldn't be updated during training
frozen_components = T_cfg["frozen_components"] frozen_components = T_cfg["frozen_components"]
# Sourced components that require resume_training # Sourced components that require resume_training
@ -167,6 +168,7 @@ def train(
with nlp.select_pipes(disable=frozen_components): with nlp.select_pipes(disable=frozen_components):
update_meta(T_cfg, nlp, info) update_meta(T_cfg, nlp, info)
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):
nlp = before_to_disk(nlp)
nlp.to_disk(output_path / "model-best") nlp.to_disk(output_path / "model-best")
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
progress.set_description(f"Epoch {info['epoch']}") progress.set_description(f"Epoch {info['epoch']}")
@ -179,6 +181,7 @@ def train(
f"Aborting and saving the final best model. " f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}" f"Encountered exception: {str(e)}"
) )
nlp = before_to_disk(nlp)
nlp.to_disk(output_path / "model-final") nlp.to_disk(output_path / "model-final")
raise e raise e
finally: finally:
@ -233,6 +236,21 @@ def create_evaluation_callback(
return evaluate return evaluate
def create_before_to_disk_callback(
callback: Optional[Callable[[Language], Language]]
) -> Callable[[Language], Language]:
def before_to_disk(nlp: Language) -> Language:
if not callback:
return nlp
modified_nlp = callback(nlp)
if not isinstance(modified_nlp, Language):
err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
raise ValueError(err)
return modified_nlp
return before_to_disk
def train_while_improving( def train_while_improving(
nlp: Language, nlp: Language,
optimizer: Optimizer, optimizer: Optimizer,

View File

@ -72,6 +72,8 @@ frozen_components = []
dev_corpus = "corpora.dev" dev_corpus = "corpora.dev"
# Location in the config where the train corpus is defined # Location in the config where the train corpus is defined
train_corpus = "corpora.train" train_corpus = "corpora.train"
# Optional callback before nlp object is saved to disk after training
before_to_disk = null
[training.logger] [training.logger]
@loggers = "spacy.ConsoleLogger.v1" @loggers = "spacy.ConsoleLogger.v1"

View File

@ -480,6 +480,9 @@ class Errors:
E201 = ("Span index out of range.") E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E914 = ("Executing {name} callback failed. Expected the function to "
"returnthe nlp object but got: {value}. Maybe you forgot to return "
"the modified object in your function?")
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected " E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
"float or int but got: {score_type}. To exclude the score from the " "float or int but got: {score_type}. To exclude the score from the "
"final score, set its weight to null in the [training.score_weights] " "final score, set its weight to null in the [training.score_weights] "

View File

@ -217,6 +217,7 @@ class ConfigSchemaTraining(BaseModel):
optimizer: Optimizer = Field(..., title="The optimizer to use") optimizer: Optimizer = Field(..., title="The optimizer to use")
logger: Logger = Field(..., title="The logger to track training progress") logger: Logger = Field(..., title="The logger to track training progress")
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training") frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
# fmt: on # fmt: on
class Config: class Config: