Adjust pretrain command

thomashacker 2023-03-23 15:04:35 +01:00
parent 28de85737f
commit 8d50a8009c


@@ -60,10 +60,14 @@ def pretrain(
     row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
     msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
 
-    def _save_model(epoch, is_temp=False):
+    def _save_model(epoch, is_temp=False, is_latest=False):
         is_temp_str = ".temp" if is_temp else ""
         with model.use_params(optimizer.averages):
-            with (output_dir / f"model{epoch}{is_temp_str}.bin").open("wb") as file_:
+            if is_latest:
+                save_path = output_dir / f"model_latest.bin"
+            else:
+                save_path = output_dir / f"model{epoch}{is_temp_str}.bin"
+            with (save_path).open("wb") as file_:
                 file_.write(model.get_ref("tok2vec").to_bytes())
             log = {
                 "nr_word": tracker.nr_word,
@@ -86,7 +90,9 @@ def pretrain(
             if P["n_save_every"] and (batch_id % P["n_save_every"] == 0):
                 _save_model(epoch, is_temp=True)
 
-        if P["n_save_epoch"]:
+        if epoch + 1 == P["max_epochs"]:
+            _save_model(epoch, is_latest=True)
+        elif P["n_save_epoch"]:
             if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1:
                 _save_model(epoch)
         else:
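
For context (not part of this commit): after the change, the final epoch's tok2vec weights are also written to model_latest.bin in the pretraining output directory. A minimal sketch of loading that artifact back into a pipeline, assuming an illustrative output directory ("pretrain_output"), a blank English pipeline, and a tok2vec architecture that matches the one used for pretraining:

    from pathlib import Path

    import spacy

    # Assumptions: default tok2vec component, same architecture/width as the
    # pretraining config, and output written to "pretrain_output/" (hypothetical).
    nlp = spacy.blank("en")
    tok2vec = nlp.add_pipe("tok2vec")
    nlp.initialize()

    weights_path = Path("pretrain_output") / "model_latest.bin"
    # The file holds the bytes written by model.get_ref("tok2vec").to_bytes()
    # in the diff above; Thinc's Model.from_bytes is the mirror call for loading.
    tok2vec.model.from_bytes(weights_path.read_bytes())

In a full training setup the same effect is normally achieved declaratively, by pointing the config's [initialize] init_tok2vec (via paths.init_tok2vec) at the new file so that spacy train loads the pretrained weights itself.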