Add new parameter for saving every n epochs in pretraining (#8912)

* Add parameter for saving every n epochs

* Add new parameter in schemas

* Add new parameter in default_config

* Adjust schemas

* Format code
This commit is contained in:
Edward 2021-08-12 11:14:48 +02:00 committed by GitHub
parent f99d6d5e39
commit 944ad6b1d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 3 deletions

View File

@ -5,6 +5,7 @@ raw_text = null
max_epochs = 1000
dropout = 0.2
n_save_every = null
n_save_epoch = null
component = "tok2vec"
layer = ""
corpus = "corpora.pretrain"

View File

@ -351,7 +351,8 @@ class ConfigSchemaPretrain(BaseModel):
# fmt: off
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
dropout: StrictFloat = Field(..., title="Dropout rate")
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
n_save_every: Optional[StrictInt] = Field(..., title="Saving additional temporary model after n batches within an epoch")
n_save_epoch: Optional[StrictInt] = Field(..., title="Saving model after every n epoch")
optimizer: Optimizer = Field(..., title="The optimizer to use")
corpus: StrictStr = Field(..., title="Path in the config to the training data")
batcher: Batcher = Field(..., title="Batcher for the training data")

View File

@ -48,7 +48,10 @@ def pretrain(
objective = model.attrs["loss"]
# TODO: move this to logger function?
tracker = ProgressTracker(frequency=10000)
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
if P["n_save_epoch"]:
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume} - saving every {P['n_save_epoch']} epoch")
else:
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
@ -77,7 +80,12 @@ def pretrain(
msg.row(progress, **row_settings)
if P["n_save_every"] and (batch_id % P["n_save_every"] == 0):
_save_model(epoch, is_temp=True)
_save_model(epoch)
if P["n_save_epoch"]:
if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1:
_save_model(epoch)
else:
_save_model(epoch)
tracker.epoch_loss = 0.0