From 8d50a8009c0161ebd056f3909e6a96acab6bf1c1 Mon Sep 17 00:00:00 2001 From: thomashacker Date: Thu, 23 Mar 2023 15:04:35 +0100 Subject: [PATCH] Adjust pretrain command --- spacy/training/pretrain.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 52af84aaf..8aae90f44 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -60,10 +60,14 @@ def pretrain( row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")} msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings) - def _save_model(epoch, is_temp=False): + def _save_model(epoch, is_temp=False, is_latest=False): is_temp_str = ".temp" if is_temp else "" with model.use_params(optimizer.averages): - with (output_dir / f"model{epoch}{is_temp_str}.bin").open("wb") as file_: + if is_latest: + save_path = output_dir / f"model_latest.bin" + else: + save_path = output_dir / f"model{epoch}{is_temp_str}.bin" + with (save_path).open("wb") as file_: file_.write(model.get_ref("tok2vec").to_bytes()) log = { "nr_word": tracker.nr_word, @@ -86,7 +90,9 @@ def pretrain( if P["n_save_every"] and (batch_id % P["n_save_every"] == 0): _save_model(epoch, is_temp=True) - if P["n_save_epoch"]: + if epoch + 1 == P["max_epochs"]: + _save_model(epoch, is_latest=True) + elif P["n_save_epoch"]: if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1: _save_model(epoch) else: