small fixes to pretrain config, init_tok2vec TODO

This commit is contained in:
svlandeg 2020-06-03 19:32:40 +02:00
parent ddf8244df9
commit 4ed6278663
3 changed files with 36 additions and 6 deletions

View File

@ -45,12 +45,16 @@ eps = 1e-8
learn_rate = 0.001 learn_rate = 0.001
[pretraining] [pretraining]
max_epochs = 100 max_epochs = 1000
start_epoch = 0
min_length = 5 min_length = 5
max_length = 500 max_length = 500
dropout = 0.2 dropout = 0.2
n_save_every = null n_save_every = null
batch_size = 3000 batch_size = 3000
seed = ${training:seed}
use_pytorch_for_gpu_memory = ${training:use_pytorch_for_gpu_memory}
init_tok2vec = null
[pretraining.model] [pretraining.model]
@architectures = "spacy.HashEmbedCNN.v1" @architectures = "spacy.HashEmbedCNN.v1"

View File

@ -16,14 +16,15 @@ from ..tokens import Doc
from ..attrs import ID, HEAD from ..attrs import ID, HEAD
from .. import util from .. import util
from ..gold import Example from ..gold import Example
from .deprecated_pretrain import _load_pretrained_tok2vec # TODO
@plac.annotations( @plac.annotations(
# fmt: off # fmt: off
texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str), texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str), vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str),
config_path=("Path to config file", "positional", None, Path),
output_dir=("Directory to write models to on each epoch", "positional", None, Path), output_dir=("Directory to write models to on each epoch", "positional", None, Path),
config_path=("Path to config file", "positional", None, Path),
use_gpu=("Use GPU", "option", "g", int), use_gpu=("Use GPU", "option", "g", int),
# fmt: on # fmt: on
) )
@ -60,8 +61,8 @@ def pretrain(
msg.info(f"Loading config from: {config_path}") msg.info(f"Loading config from: {config_path}")
config = util.load_config(config_path, create_objects=False) config = util.load_config(config_path, create_objects=False)
util.fix_random_seed(config["training"]["seed"]) util.fix_random_seed(config["pretraining"]["seed"])
if config["training"]["use_pytorch_for_gpu_memory"]: if config["pretraining"]["use_pytorch_for_gpu_memory"]:
use_pytorch_for_gpu_memory() use_pytorch_for_gpu_memory()
if output_dir.exists() and [p for p in output_dir.iterdir()]: if output_dir.exists() and [p for p in output_dir.iterdir()]:
@ -100,8 +101,33 @@ def pretrain(
tok2vec = pretrain_config["model"] tok2vec = pretrain_config["model"]
model = create_pretraining_model(nlp, tok2vec) model = create_pretraining_model(nlp, tok2vec)
optimizer = pretrain_config["optimizer"] optimizer = pretrain_config["optimizer"]
init_tok2vec = pretrain_config["init_tok2vec"]
epoch_start = pretrain_config["epoch_start"]
# Load in pretrained weights - TODO test
if init_tok2vec is not None:
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
msg.text(f"Loaded pretrained tok2vec for: {components}")
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
if model_name:
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
epoch_start = int(model_name.group(0)[5:][:-4]) + 1
else:
if not epoch_start:
msg.fail(
"You have to use the epoch_start setting when using a renamed weight file for init_tok2vec",
exits=True,
)
elif epoch_start < 0:
msg.fail(
f"The setting epoch_start has to be greater or equal to 0. {epoch_start} is invalid",
exits=True,
)
else:
# Without 'init-tok2vec' the 'epoch_start' setting is ignored
epoch_start = 0
epoch_start = 0 # TODO
tracker = ProgressTracker(frequency=10000) tracker = ProgressTracker(frequency=10000)
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}") msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_start}")
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")} row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}