Merge pull request #6134 from explosion/feature/training_before_to_disk

This commit is contained in:
Ines Montani 2020-09-24 14:44:11 +02:00 committed by GitHub
commit 74e1f192b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 20 deletions

View File

@ -97,6 +97,7 @@ def train(
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
batcher = T_cfg["batcher"]
train_logger = T_cfg["logger"]
before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
# Components that shouldn't be updated during training
frozen_components = T_cfg["frozen_components"]
# Sourced components that require resume_training
@ -167,6 +168,7 @@ def train(
with nlp.select_pipes(disable=frozen_components):
update_meta(T_cfg, nlp, info)
with nlp.use_params(optimizer.averages):
nlp = before_to_disk(nlp)
nlp.to_disk(output_path / "model-best")
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
progress.set_description(f"Epoch {info['epoch']}")
@ -179,6 +181,7 @@ def train(
f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}"
)
nlp = before_to_disk(nlp)
nlp.to_disk(output_path / "model-final")
raise e
finally:
@ -233,6 +236,21 @@ def create_evaluation_callback(
return evaluate
def create_before_to_disk_callback(
callback: Optional[Callable[[Language], Language]]
) -> Callable[[Language], Language]:
def before_to_disk(nlp: Language) -> Language:
if not callback:
return nlp
modified_nlp = callback(nlp)
if not isinstance(modified_nlp, Language):
err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
raise ValueError(err)
return modified_nlp
return before_to_disk
def train_while_improving(
nlp: Language,
optimizer: Optimizer,

View File

@ -72,6 +72,8 @@ frozen_components = []
dev_corpus = "corpora.dev"
# Location in the config where the train corpus is defined
train_corpus = "corpora.train"
# Optional callback before nlp object is saved to disk after training
before_to_disk = null
[training.logger]
@loggers = "spacy.ConsoleLogger.v1"

View File

@ -480,6 +480,9 @@ class Errors:
E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master
E914 = ("Executing {name} callback failed. Expected the function to "
"return the nlp object but got: {value}. Maybe you forgot to return "
"the modified object in your function?")
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
"float or int but got: {score_type}. To exclude the score from the "
"final score, set its weight to null in the [training.score_weights] "

View File

@ -216,6 +216,7 @@ class ConfigSchemaTraining(BaseModel):
optimizer: Optimizer = Field(..., title="The optimizer to use")
logger: Logger = Field(..., title="The logger to track training progress")
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
# fmt: on
class Config:

View File

@ -181,9 +181,10 @@ This section defines settings and controls for the training and evaluation
process that are used when you run [`spacy train`](/api/cli#train).
| Name | Description |
| --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
| `before_to_disk` | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |