mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add [training.before_to_disk] callback
This commit is contained in:
parent
d7ab6a2ffe
commit
be56c0994b
|
@ -97,6 +97,7 @@ def train(
|
||||||
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
||||||
batcher = T_cfg["batcher"]
|
batcher = T_cfg["batcher"]
|
||||||
train_logger = T_cfg["logger"]
|
train_logger = T_cfg["logger"]
|
||||||
|
before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
|
||||||
# Components that shouldn't be updated during training
|
# Components that shouldn't be updated during training
|
||||||
frozen_components = T_cfg["frozen_components"]
|
frozen_components = T_cfg["frozen_components"]
|
||||||
# Sourced components that require resume_training
|
# Sourced components that require resume_training
|
||||||
|
@ -167,6 +168,7 @@ def train(
|
||||||
with nlp.select_pipes(disable=frozen_components):
|
with nlp.select_pipes(disable=frozen_components):
|
||||||
update_meta(T_cfg, nlp, info)
|
update_meta(T_cfg, nlp, info)
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
|
nlp = before_to_disk(nlp)
|
||||||
nlp.to_disk(output_path / "model-best")
|
nlp.to_disk(output_path / "model-best")
|
||||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||||
progress.set_description(f"Epoch {info['epoch']}")
|
progress.set_description(f"Epoch {info['epoch']}")
|
||||||
|
@ -179,6 +181,7 @@ def train(
|
||||||
f"Aborting and saving the final best model. "
|
f"Aborting and saving the final best model. "
|
||||||
f"Encountered exception: {str(e)}"
|
f"Encountered exception: {str(e)}"
|
||||||
)
|
)
|
||||||
|
nlp = before_to_disk(nlp)
|
||||||
nlp.to_disk(output_path / "model-final")
|
nlp.to_disk(output_path / "model-final")
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
|
@ -233,6 +236,21 @@ def create_evaluation_callback(
|
||||||
return evaluate
|
return evaluate
|
||||||
|
|
||||||
|
|
||||||
|
def create_before_to_disk_callback(
|
||||||
|
callback: Optional[Callable[[Language], Language]]
|
||||||
|
) -> Callable[[Language], Language]:
|
||||||
|
def before_to_disk(nlp: Language) -> Language:
|
||||||
|
if not callback:
|
||||||
|
return nlp
|
||||||
|
modified_nlp = callback(nlp)
|
||||||
|
if not isinstance(modified_nlp, Language):
|
||||||
|
err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
|
||||||
|
raise ValueError(err)
|
||||||
|
return modified_nlp
|
||||||
|
|
||||||
|
return before_to_disk
|
||||||
|
|
||||||
|
|
||||||
def train_while_improving(
|
def train_while_improving(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
|
|
|
@ -72,6 +72,8 @@ frozen_components = []
|
||||||
dev_corpus = "corpora.dev"
|
dev_corpus = "corpora.dev"
|
||||||
# Location in the config where the train corpus is defined
|
# Location in the config where the train corpus is defined
|
||||||
train_corpus = "corpora.train"
|
train_corpus = "corpora.train"
|
||||||
|
# Optional callback before nlp object is saved to disk after training
|
||||||
|
before_to_disk = null
|
||||||
|
|
||||||
[training.logger]
|
[training.logger]
|
||||||
@loggers = "spacy.ConsoleLogger.v1"
|
@loggers = "spacy.ConsoleLogger.v1"
|
||||||
|
|
|
@ -480,6 +480,9 @@ class Errors:
|
||||||
E201 = ("Span index out of range.")
|
E201 = ("Span index out of range.")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E914 = ("Executing {name} callback failed. Expected the function to "
|
||||||
|
"returnthe nlp object but got: {value}. Maybe you forgot to return "
|
||||||
|
"the modified object in your function?")
|
||||||
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
|
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
|
||||||
"float or int but got: {score_type}. To exclude the score from the "
|
"float or int but got: {score_type}. To exclude the score from the "
|
||||||
"final score, set its weight to null in the [training.score_weights] "
|
"final score, set its weight to null in the [training.score_weights] "
|
||||||
|
|
|
@ -217,6 +217,7 @@ class ConfigSchemaTraining(BaseModel):
|
||||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||||
logger: Logger = Field(..., title="The logger to track training progress")
|
logger: Logger = Field(..., title="The logger to track training progress")
|
||||||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||||
|
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user