mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Add new parameter for saving every n epochs in pretraining (#8912)
* Add parameter for saving every n epochs
* Add new parameter in schemas
* Add new parameter in default_config
* Adjust schemas
* Format code
This commit is contained in:
parent
f99d6d5e39
commit
944ad6b1d4
|
@ -5,6 +5,7 @@ raw_text = null
|
|||
max_epochs = 1000
|
||||
dropout = 0.2
|
||||
n_save_every = null
|
||||
n_save_epoch = null
|
||||
component = "tok2vec"
|
||||
layer = ""
|
||||
corpus = "corpora.pretrain"
|
||||
|
|
|
@ -351,7 +351,8 @@ class ConfigSchemaPretrain(BaseModel):
|
|||
# fmt: off
|
||||
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
|
||||
n_save_every: Optional[StrictInt] = Field(..., title="Saving additional temporary model after n batches within an epoch")
|
||||
n_save_epoch: Optional[StrictInt] = Field(..., title="Saving model after every n epoch")
|
||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
corpus: StrictStr = Field(..., title="Path in the config to the training data")
|
||||
batcher: Batcher = Field(..., title="Batcher for the training data")
|
||||
|
|
|
@ -48,7 +48,10 @@ def pretrain(
|
|||
objective = model.attrs["loss"]
|
||||
# TODO: move this to logger function?
|
||||
tracker = ProgressTracker(frequency=10000)
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
|
||||
if P["n_save_epoch"]:
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume} - saving every {P['n_save_epoch']} epoch")
|
||||
else:
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
|
||||
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
|
||||
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
||||
|
||||
|
@ -77,7 +80,12 @@ def pretrain(
|
|||
msg.row(progress, **row_settings)
|
||||
if P["n_save_every"] and (batch_id % P["n_save_every"] == 0):
|
||||
_save_model(epoch, is_temp=True)
|
||||
_save_model(epoch)
|
||||
|
||||
if P["n_save_epoch"]:
|
||||
if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1:
|
||||
_save_model(epoch)
|
||||
else:
|
||||
_save_model(epoch)
|
||||
tracker.epoch_loss = 0.0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user