mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
rename init_tok2vec to resume
This commit is contained in:
parent
4ed6278663
commit
07886a3de3
|
@ -46,7 +46,6 @@ learn_rate = 0.001
|
|||
|
||||
[pretraining]
|
||||
max_epochs = 1000
|
||||
start_epoch = 0
|
||||
min_length = 5
|
||||
max_length = 500
|
||||
dropout = 0.2
|
||||
|
@ -54,7 +53,6 @@ n_save_every = null
|
|||
batch_size = 3000
|
||||
seed = ${training:seed}
|
||||
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
|
||||
init_tok2vec = null
|
||||
|
||||
[pretraining.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
|
|
|
@ -16,7 +16,6 @@ from ..tokens import Doc
|
|||
from ..attrs import ID, HEAD
|
||||
from .. import util
|
||||
from ..gold import Example
|
||||
from .deprecated_pretrain import _load_pretrained_tok2vec # TODO
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
@ -26,6 +25,9 @@ from .deprecated_pretrain import _load_pretrained_tok2vec # TODO
|
|||
output_dir=("Directory to write models to on each epoch", "positional", None, Path),
|
||||
config_path=("Path to config file", "positional", None, Path),
|
||||
use_gpu=("Use GPU", "option", "g", int),
|
||||
resume_path=("Path to pretrained weights from which to resume pretraining", "option","r", Path),
|
||||
epoch_resume=("The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files.","option", "er", int),
|
||||
|
||||
# fmt: on
|
||||
)
|
||||
def pretrain(
|
||||
|
@ -34,6 +36,8 @@ def pretrain(
|
|||
config_path,
|
||||
output_dir,
|
||||
use_gpu=-1,
|
||||
resume_path=None,
|
||||
epoch_resume=None,
|
||||
):
|
||||
"""
|
||||
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
|
||||
|
@ -66,8 +70,16 @@ def pretrain(
|
|||
use_pytorch_for_gpu_memory()
|
||||
|
||||
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
||||
if resume_path:
|
||||
msg.warn(
|
||||
"Output directory is not empty",
|
||||
"Output directory is not empty. ",
|
||||
"If you're resuming a run from a previous "
|
||||
"model, the old models for the consecutive epochs will be overwritten "
|
||||
"with the new ones.",
|
||||
)
|
||||
else:
|
||||
msg.warn(
|
||||
"Output directory is not empty. ",
|
||||
"It is better to use an empty directory or refer to a new output path, "
|
||||
"then the new directory will be created for you.",
|
||||
)
|
||||
|
@ -92,7 +104,7 @@ def pretrain(
|
|||
msg.good("Loaded input texts")
|
||||
random.shuffle(texts)
|
||||
else: # reading from stdin
|
||||
msg.text("Reading input text from stdin...")
|
||||
msg.info("Reading input text from stdin...")
|
||||
texts = srsly.read_jsonl("-")
|
||||
|
||||
with msg.loading(f"Loading model '{vectors_model}'..."):
|
||||
|
@ -101,35 +113,36 @@ def pretrain(
|
|||
tok2vec = pretrain_config["model"]
|
||||
model = create_pretraining_model(nlp, tok2vec)
|
||||
optimizer = pretrain_config["optimizer"]
|
||||
init_tok2vec = pretrain_config["init_tok2vec"]
|
||||
epoch_start = pretrain_config["epoch_start"]
|
||||
|
||||
# Load in pretrained weights - TODO test
|
||||
if init_tok2vec is not None:
|
||||
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
|
||||
msg.text(f"Loaded pretrained tok2vec for: {components}")
|
||||
# Load in pretrained weights to resume from
|
||||
if resume_path is not None:
|
||||
msg.info(f"Resume training tok2vec from: {resume_path}")
|
||||
with resume_path.open("rb") as file_:
|
||||
weights_data = file_.read()
|
||||
model.get_ref("tok2vec").from_bytes(weights_data)
|
||||
# Parse the epoch number from the given weight file
|
||||
model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
|
||||
model_name = re.search(r"model\d+\.bin", str(resume_path))
|
||||
if model_name:
|
||||
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
||||
epoch_start = int(model_name.group(0)[5:][:-4]) + 1
|
||||
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
|
||||
msg.info(f"Resuming from epoch: {epoch_resume}")
|
||||
else:
|
||||
if not epoch_start:
|
||||
if not epoch_resume:
|
||||
msg.fail(
|
||||
"You have to use the epoch_start setting when using a renamed weight file for init_tok2vec",
|
||||
"You have to use the --epoch_resume setting when using a renamed weight file for --resume_path",
|
||||
exits=True,
|
||||
)
|
||||
elif epoch_start < 0:
|
||||
elif epoch_resume < 0:
|
||||
msg.fail(
|
||||
f"The setting epoch_start has to be greater or equal to 0. {epoch_start} is invalid",
|
||||
f"The setting --epoch_resume has to be greater or equal to 0. {epoch_resume} is invalid",
|
||||
exits=True,
|
||||
)
|
||||
else:
|
||||
# Without 'init-tok2vec' the 'epoch_start' setting is ignored
|
||||
epoch_start = 0
|
||||
# Without 'resume_path' the 'epoch_resume' setting is ignored
|
||||
epoch_resume = 0
|
||||
|
||||
tracker = ProgressTracker(frequency=10000)
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
|
||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
|
||||
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
|
||||
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
||||
|
||||
|
@ -149,7 +162,7 @@ def pretrain(
|
|||
|
||||
skip_counter = 0
|
||||
loss_func = pretrain_config["loss_func"]
|
||||
for epoch in range(epoch_start, pretrain_config["max_epochs"]):
|
||||
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
|
||||
examples = [Example(doc=text) for text in texts]
|
||||
batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
|
||||
for batch_id, batch in enumerate(batches):
|
||||
|
|
Loading…
Reference in New Issue
Block a user