mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add [training.before_to_disk] callback
This commit is contained in:
parent
d7ab6a2ffe
commit
be56c0994b
|
@ -97,6 +97,7 @@ def train(
|
|||
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
||||
batcher = T_cfg["batcher"]
|
||||
train_logger = T_cfg["logger"]
|
||||
before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
|
||||
# Components that shouldn't be updated during training
|
||||
frozen_components = T_cfg["frozen_components"]
|
||||
# Sourced components that require resume_training
|
||||
|
@ -167,6 +168,7 @@ def train(
|
|||
with nlp.select_pipes(disable=frozen_components):
|
||||
update_meta(T_cfg, nlp, info)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
nlp = before_to_disk(nlp)
|
||||
nlp.to_disk(output_path / "model-best")
|
||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||
progress.set_description(f"Epoch {info['epoch']}")
|
||||
|
@ -179,6 +181,7 @@ def train(
|
|||
f"Aborting and saving the final best model. "
|
||||
f"Encountered exception: {str(e)}"
|
||||
)
|
||||
nlp = before_to_disk(nlp)
|
||||
nlp.to_disk(output_path / "model-final")
|
||||
raise e
|
||||
finally:
|
||||
|
@ -233,6 +236,21 @@ def create_evaluation_callback(
|
|||
return evaluate
|
||||
|
||||
|
||||
def create_before_to_disk_callback(
|
||||
callback: Optional[Callable[[Language], Language]]
|
||||
) -> Callable[[Language], Language]:
|
||||
def before_to_disk(nlp: Language) -> Language:
|
||||
if not callback:
|
||||
return nlp
|
||||
modified_nlp = callback(nlp)
|
||||
if not isinstance(modified_nlp, Language):
|
||||
err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
|
||||
raise ValueError(err)
|
||||
return modified_nlp
|
||||
|
||||
return before_to_disk
|
||||
|
||||
|
||||
def train_while_improving(
|
||||
nlp: Language,
|
||||
optimizer: Optimizer,
|
||||
|
|
|
@ -72,6 +72,8 @@ frozen_components = []
|
|||
dev_corpus = "corpora.dev"
|
||||
# Location in the config where the train corpus is defined
|
||||
train_corpus = "corpora.train"
|
||||
# Optional callback before nlp object is saved to disk after training
|
||||
before_to_disk = null
|
||||
|
||||
[training.logger]
|
||||
@loggers = "spacy.ConsoleLogger.v1"
|
||||
|
|
|
@ -480,6 +480,9 @@ class Errors:
|
|||
E201 = ("Span index out of range.")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E914 = ("Executing {name} callback failed. Expected the function to "
|
||||
"returnthe nlp object but got: {value}. Maybe you forgot to return "
|
||||
"the modified object in your function?")
|
||||
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
|
||||
"float or int but got: {score_type}. To exclude the score from the "
|
||||
"final score, set its weight to null in the [training.score_weights] "
|
||||
|
|
|
@ -217,6 +217,7 @@ class ConfigSchemaTraining(BaseModel):
|
|||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
logger: Logger = Field(..., title="The logger to track training progress")
|
||||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
|
|
Loading…
Reference in New Issue
Block a user