diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 8aae90f44..1d0533ff5 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -60,11 +60,11 @@ def pretrain( row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")} msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings) - def _save_model(epoch, is_temp=False, is_latest=False): + def _save_model(epoch, is_temp=False, is_last=False): is_temp_str = ".temp" if is_temp else "" with model.use_params(optimizer.averages): - if is_latest: - save_path = output_dir / f"model_latest.bin" + if is_last: + save_path = output_dir / f"model_last.bin" else: save_path = output_dir / f"model{epoch}{is_temp_str}.bin" with (save_path).open("wb") as file_: @@ -81,23 +81,24 @@ def pretrain( # TODO: I think we probably want this to look more like the # 'create_train_batches' function? for epoch in range(epoch_resume, P["max_epochs"]): - for batch_id, batch in enumerate(batcher(corpus(nlp))): - docs = ensure_docs(batch) - loss = make_update(model, docs, optimizer, objective) - progress = tracker.update(epoch, loss, docs) - if progress: - msg.row(progress, **row_settings) - if P["n_save_every"] and (batch_id % P["n_save_every"] == 0): - _save_model(epoch, is_temp=True) + try: + for batch_id, batch in enumerate(batcher(corpus(nlp))): + docs = ensure_docs(batch) + loss = make_update(model, docs, optimizer, objective) + progress = tracker.update(epoch, loss, docs) + if progress: + msg.row(progress, **row_settings) + if P["n_save_every"] and (batch_id % P["n_save_every"] == 0): + _save_model(epoch, is_temp=True) - if epoch + 1 == P["max_epochs"]: - _save_model(epoch, is_latest=True) - elif P["n_save_epoch"]: - if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1: + if P["n_save_epoch"]: + if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1: + _save_model(epoch) + else: _save_model(epoch) - else: - _save_model(epoch) - tracker.epoch_loss = 0.0 + tracker.epoch_loss = 0.0 + finally: + _save_model(epoch, is_last=True) def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]: