small little fixes

This commit is contained in:
svlandeg 2020-06-03 22:17:02 +02:00
parent 07886a3de3
commit 1775f54a26

View File

@ -5,10 +5,9 @@ import re
from collections import Counter from collections import Counter
import plac import plac
from pathlib import Path from pathlib import Path
from thinc.api import Linear, Maxout, chain, list2array from thinc.api import Linear, Maxout, chain, list2array, use_pytorch_for_gpu_memory
from wasabi import msg from wasabi import msg
import srsly import srsly
from thinc.api import use_pytorch_for_gpu_memory
from ..errors import Errors from ..errors import Errors
from ..ml.models.multi_task import build_masked_language_model from ..ml.models.multi_task import build_masked_language_model
@ -73,8 +72,8 @@ def pretrain(
if resume_path: if resume_path:
msg.warn( msg.warn(
"Output directory is not empty. ", "Output directory is not empty. ",
"If you're resuming a run from a previous " "If you're resuming a run from a previous model in this directory, "
"model, the old models for the consecutive epochs will be overwritten " "the old models for the consecutive epochs will be overwritten "
"with the new ones.", "with the new ones.",
) )
else: else:
@ -129,16 +128,18 @@ def pretrain(
else: else:
if not epoch_resume: if not epoch_resume:
msg.fail( msg.fail(
"You have to use the --epoch_resume setting when using a renamed weight file for --resume_path", "You have to use the --epoch-resume setting when using a renamed weight file for --resume-path",
exits=True, exits=True,
) )
elif epoch_resume < 0: elif epoch_resume < 0:
msg.fail( msg.fail(
f"The setting --epoch_resume has to be greater or equal to 0. {epoch_resume} is invalid", f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
exits=True, exits=True,
) )
else:
msg.info(f"Resuming from epoch: {epoch_resume}")
else: else:
# Without 'resume_path' the 'epoch_resume' setting is ignored # Without '--resume-path' the '--epoch-resume' argument is ignored
epoch_resume = 0 epoch_resume = 0
tracker = ProgressTracker(frequency=10000) tracker = ProgressTracker(frequency=10000)