Fix inference of epoch_resume (#9084)

* Fix inference of epoch_resume

When an epoch_resume value is not specified individually, it can often
be inferred from the filename. The value inference code was there but
the value wasn't passed back to the training loop.

This also adds a specific error in the case where no epoch_resume value
is provided and it can't be inferred from the filename.

* Add new error

* Always use the epoch resume value if specified

Before this, the value in the filename was used if found.
This commit is contained in:
Paul O'Leary McCann 2021-09-01 14:17:42 +09:00 committed by GitHub
parent a17b06d18b
commit f803a84571
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 10 deletions

View File

@@ -869,6 +869,8 @@ class Errors:
E1019 = ("`noun_chunks` requires the pos tagging, which requires a "
"statistical model to be installed and loaded. For more info, see "
"the documentation:\nhttps://spacy.io/usage/models")
E1020 = ("No `epoch_resume` value specified and could not infer one from "
"filename. Specify an epoch to resume from.")
# Deprecated model shortcuts, only used in errors and warnings

View File

@@ -41,10 +41,11 @@ def pretrain(
optimizer = P["optimizer"]
# Load in pretrained weights to resume from
if resume_path is not None:
_resume_model(model, resume_path, epoch_resume, silent=silent)
epoch_resume = _resume_model(model, resume_path, epoch_resume, silent=silent)
else:
# Without '--resume-path' the '--epoch-resume' argument is ignored
epoch_resume = 0
objective = model.attrs["loss"]
# TODO: move this to logger function?
tracker = ProgressTracker(frequency=10000)
@@ -93,20 +94,25 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
def _resume_model(
model: Model, resume_path: Path, epoch_resume: int, silent: bool = True
) -> None:
) -> int:
msg = Printer(no_print=silent)
msg.info(f"Resume training tok2vec from: {resume_path}")
with resume_path.open("rb") as file_:
weights_data = file_.read()
model.get_ref("tok2vec").from_bytes(weights_data)
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(resume_path))
if model_name:
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
msg.info(f"Resuming from epoch: {epoch_resume}")
else:
msg.info(f"Resuming from epoch: {epoch_resume}")
if epoch_resume is None:
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(resume_path))
if model_name:
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
else:
# No epoch given and couldn't infer it
raise ValueError(Errors.E1020)
msg.info(f"Resuming from epoch: {epoch_resume}")
return epoch_resume
def make_update(