From f803a8457177f9fb852467f1692f0db1547b58a3 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Wed, 1 Sep 2021 14:17:42 +0900 Subject: [PATCH] Fix inference of epoch_resume (#9084) * Fix inference of epoch_resume When an epoch_resume value is not specified individually, it can often be inferred from the filename. The value inference code was there but the value wasn't passed back to the training loop. This also adds a specific error in the case where no epoch_resume value is provided and it can't be inferred from the filename. * Add new error * Always use the epoch resume value if specified Before this the value in the filename was used if found --- spacy/errors.py | 2 ++ spacy/training/pretrain.py | 26 ++++++++++++++++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index a206826ff..0e1a294c3 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -869,6 +869,8 @@ class Errors: E1019 = ("`noun_chunks` requires the pos tagging, which requires a " "statistical model to be installed and loaded. For more info, see " "the documentation:\nhttps://spacy.io/usage/models") + E1020 = ("No `epoch_resume` value specified and could not infer one from " + "filename. Specify an epoch to resume from.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 6d7850212..0228f2947 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -41,10 +41,11 @@ def pretrain( optimizer = P["optimizer"] # Load in pretrained weights to resume from if resume_path is not None: - _resume_model(model, resume_path, epoch_resume, silent=silent) + epoch_resume = _resume_model(model, resume_path, epoch_resume, silent=silent) else: # Without '--resume-path' the '--epoch-resume' argument is ignored epoch_resume = 0 + objective = model.attrs["loss"] # TODO: move this to logger function? tracker = ProgressTracker(frequency=10000) @@ -93,20 +94,25 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]: def _resume_model( model: Model, resume_path: Path, epoch_resume: int, silent: bool = True -) -> None: +) -> int: msg = Printer(no_print=silent) msg.info(f"Resume training tok2vec from: {resume_path}") with resume_path.open("rb") as file_: weights_data = file_.read() model.get_ref("tok2vec").from_bytes(weights_data) - # Parse the epoch number from the given weight file - model_name = re.search(r"model\d+\.bin", str(resume_path)) - if model_name: - # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin' - epoch_resume = int(model_name.group(0)[5:][:-4]) + 1 - msg.info(f"Resuming from epoch: {epoch_resume}") - else: - msg.info(f"Resuming from epoch: {epoch_resume}") + + if epoch_resume is None: + # Parse the epoch number from the given weight file + model_name = re.search(r"model\d+\.bin", str(resume_path)) + if model_name: + # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin' + epoch_resume = int(model_name.group(0)[5:][:-4]) + 1 + else: + # No epoch given and couldn't infer it + raise ValueError(Errors.E1020) + + msg.info(f"Resuming from epoch: {epoch_resume}") + return epoch_resume def make_update(