rename init_tok2vec to resume
parent 4ed6278663
commit 07886a3de3
@@ -46,7 +46,6 @@ learn_rate = 0.001
 
 [pretraining]
 max_epochs = 1000
-start_epoch = 0
 min_length = 5
 max_length = 500
 dropout = 0.2
@@ -54,7 +53,6 @@ n_save_every = null
 batch_size = 3000
 seed = ${training:seed}
 use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
-init_tok2vec = null
 
 [pretraining.model]
 @architectures = "spacy.HashEmbedCNN.v1"
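For reference, after this change the [pretraining] block no longer exposes start_epoch or init_tok2vec; resuming is handled entirely through the new CLI options below. A minimal sketch of the remaining block, assembled only from the context lines above (keys that fall between the two hunks, such as n_save_every, are elided and their exact position is not shown here):

[pretraining]
max_epochs = 1000
min_length = 5
max_length = 500
dropout = 0.2
...
batch_size = 3000
seed = ${training:seed}
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}

[pretraining.model]
@architectures = "spacy.HashEmbedCNN.v1"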
@@ -16,7 +16,6 @@ from ..tokens import Doc
 from ..attrs import ID, HEAD
 from .. import util
 from ..gold import Example
-from .deprecated_pretrain import _load_pretrained_tok2vec  # TODO
 
 
 @plac.annotations(
@@ -26,7 +25,10 @@ from .deprecated_pretrain import _load_pretrained_tok2vec  # TODO
     output_dir=("Directory to write models to on each epoch", "positional", None, Path),
     config_path=("Path to config file", "positional", None, Path),
     use_gpu=("Use GPU", "option", "g", int),
-    # fmt: on
+    resume_path=("Path to pretrained weights from which to resume pretraining", "option", "r", Path),
+    epoch_resume=("The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files.", "option", "er", int),
+    # fmt: on
 )
 def pretrain(
     texts_loc,
@@ -34,6 +36,8 @@ def pretrain(
     config_path,
     output_dir,
     use_gpu=-1,
+    resume_path=None,
+    epoch_resume=None,
 ):
     """
     Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
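With these options in place, resuming is driven from the command line rather than the config. A hypothetical invocation, assuming the function is still registered as the `spacy pretrain` subcommand, that vectors_model is the second positional (as suggested by the loading message later in the file), and that an earlier run left default-named weight files in the output directory; all paths and the vectors model name are placeholders, the short option forms -r and -er come from the annotations above, and the exact long-option spelling depends on how plac renders the parameter names:

# Resume from the weights written at epoch 249 of a previous run;
# the epoch is parsed from the file name, so -er is not needed here.
python -m spacy pretrain texts.jsonl en_vectors_web_lg pretrain.cfg ./output \
    -r ./output/model249.bin

# If the weight file was renamed and no longer matches model<N>.bin,
# the resume epoch must be given explicitly.
python -m spacy pretrain texts.jsonl en_vectors_web_lg pretrain.cfg ./output \
    -r ./output/my_weights.bin -er 250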
@@ -66,11 +70,19 @@ def pretrain(
         use_pytorch_for_gpu_memory()
 
     if output_dir.exists() and [p for p in output_dir.iterdir()]:
-        msg.warn(
-            "Output directory is not empty",
-            "It is better to use an empty directory or refer to a new output path, "
-            "then the new directory will be created for you.",
-        )
+        if resume_path:
+            msg.warn(
+                "Output directory is not empty. ",
+                "If you're resuming a run from a previous "
+                "model, the old models for the consecutive epochs will be overwritten "
+                "with the new ones.",
+            )
+        else:
+            msg.warn(
+                "Output directory is not empty. ",
+                "It is better to use an empty directory or refer to a new output path, "
+                "then the new directory will be created for you.",
+            )
     if not output_dir.exists():
         output_dir.mkdir()
         msg.good(f"Created output directory: {output_dir}")
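A non-empty output directory is still only a warning, but the message now branches on resume_path: when resuming, overwriting the weight files of the epochs that follow is expected, so the warning describes that instead of recommending a fresh directory. Together with epoch_resume, this is what the new option's help text means by preventing unintended overwriting of existing weight files.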
@@ -92,7 +104,7 @@ def pretrain(
         msg.good("Loaded input texts")
         random.shuffle(texts)
     else:  # reading from stdin
-        msg.text("Reading input text from stdin...")
+        msg.info("Reading input text from stdin...")
         texts = srsly.read_jsonl("-")
 
     with msg.loading(f"Loading model '{vectors_model}'..."):
@@ -101,35 +113,36 @@ def pretrain(
     tok2vec = pretrain_config["model"]
     model = create_pretraining_model(nlp, tok2vec)
     optimizer = pretrain_config["optimizer"]
-    init_tok2vec = pretrain_config["init_tok2vec"]
-    epoch_start = pretrain_config["epoch_start"]
 
-    # Load in pretrained weights - TODO test
-    if init_tok2vec is not None:
-        components = _load_pretrained_tok2vec(nlp, init_tok2vec)
-        msg.text(f"Loaded pretrained tok2vec for: {components}")
+    # Load in pretrained weights to resume from
+    if resume_path is not None:
+        msg.info(f"Resume training tok2vec from: {resume_path}")
+        with resume_path.open("rb") as file_:
+            weights_data = file_.read()
+            model.get_ref("tok2vec").from_bytes(weights_data)
         # Parse the epoch number from the given weight file
-        model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
+        model_name = re.search(r"model\d+\.bin", str(resume_path))
         if model_name:
             # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
-            epoch_start = int(model_name.group(0)[5:][:-4]) + 1
+            epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
+            msg.info(f"Resuming from epoch: {epoch_resume}")
         else:
-            if not epoch_start:
+            if not epoch_resume:
                 msg.fail(
-                    "You have to use the epoch_start setting when using a renamed weight file for init_tok2vec",
+                    "You have to use the --epoch_resume setting when using a renamed weight file for --resume_path",
                     exits=True,
                 )
-            elif epoch_start < 0:
+            elif epoch_resume < 0:
                 msg.fail(
-                    f"The setting epoch_start has to be greater or equal to 0. {epoch_start} is invalid",
+                    f"The setting --epoch_resume has to be greater or equal to 0. {epoch_resume} is invalid",
                     exits=True,
                 )
     else:
-        # Without 'init-tok2vec' the 'epoch_start' setting is ignored
-        epoch_start = 0
+        # Without 'resume_path' the 'epoch_resume' setting is ignored
+        epoch_resume = 0
 
     tracker = ProgressTracker(frequency=10000)
-    msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
+    msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
     row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
     msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
 
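The resume epoch is recovered from the default weight file name by cutting off the "model" prefix and the ".bin" suffix and adding one. A standalone sketch of just that parsing step; the helper name and file names are illustrative, not from the commit:

import re
from pathlib import Path
from typing import Optional


def infer_epoch_resume(resume_path: Path) -> Optional[int]:
    """Return the epoch to resume from for a default-named weight file,
    e.g. model249.bin -> 250, or None if the file was renamed."""
    model_name = re.search(r"model\d+\.bin", str(resume_path))
    if not model_name:
        return None
    # model_name.group(0) is e.g. "model249.bin": drop "model" (5 chars) and ".bin"
    return int(model_name.group(0)[5:][:-4]) + 1


print(infer_epoch_resume(Path("output/model249.bin")))    # 250
print(infer_epoch_resume(Path("output/my_weights.bin")))  # None -> --epoch_resume required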
@@ -149,7 +162,7 @@ def pretrain(
 
     skip_counter = 0
     loss_func = pretrain_config["loss_func"]
-    for epoch in range(epoch_start, pretrain_config["max_epochs"]):
+    for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
         examples = [Example(doc=text) for text in texts]
         batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
        for batch_id, batch in enumerate(batches):
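Starting the loop at epoch_resume rather than 0 is the other half of the rename: assuming the weight files keep the model<N>.bin naming that the regex above expects, a resumed run continues numbering from the file it was resumed from instead of writing model0.bin onwards over earlier results.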