Update pretrain to prevent unintended overwriting of weight fil… (#3902)

* Update pretrain to prevent unintended overwriting of weight files for #3859

* Add '--epoch-start' to pretrain docs

* Add mising pretrain arguments to bash example

* Update doc tag for v2.1.5
This commit is contained in:
Björn Böing 2019-07-09 21:48:30 +02:00 committed by Ines Montani
parent 6d577f0b92
commit 04982ccc40
2 changed files with 36 additions and 6 deletions

View File

@ -5,6 +5,7 @@ import plac
import random
import numpy
import time
import re
from collections import Counter
from pathlib import Path
from thinc.v2v import Affine, Maxout
@ -65,6 +66,13 @@ from .train import _load_pretrained_tok2vec
"t2v",
Path,
),
epoch_start=(
"The epoch to start counting at. Only relevant when using '--init-tok2vec' and the given weight file has been "
"renamed. Prevents unintended overwriting of existing weight files.",
"option",
"es",
int
),
)
def pretrain(
texts_loc,
@ -83,6 +91,7 @@ def pretrain(
seed=0,
n_save_every=None,
init_tok2vec=None,
epoch_start=None,
):
"""
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
@ -151,9 +160,29 @@ def pretrain(
if init_tok2vec is not None:
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
msg.text("Loaded pretrained tok2vec for: {}".format(components))
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
if model_name:
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
epoch_start = int(model_name.group(0)[5:][:-4]) + 1
else:
if not epoch_start:
msg.fail(
"You have to use the '--epoch-start' argument when using a renamed weight file for "
"'--init-tok2vec'", exits=True
)
elif epoch_start < 0:
msg.fail(
"The argument '--epoch-start' has to be greater or equal to 0. '%d' is invalid" % epoch_start,
exits=True
)
else:
# Without '--init-tok2vec' the '--epoch-start' argument is ignored
epoch_start = 0
optimizer = create_default_optimizer(model.ops)
tracker = ProgressTracker(frequency=10000)
msg.divider("Pre-training tok2vec layer")
msg.divider("Pre-training tok2vec layer - starting at epoch %d" % epoch_start)
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
@ -174,7 +203,7 @@ def pretrain(
file_.write(srsly.json_dumps(log) + "\n")
skip_counter = 0
for epoch in range(n_iter):
for epoch in range(epoch_start, n_iter + epoch_start):
for batch_id, batch in enumerate(
util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
):

View File

@ -284,9 +284,9 @@ same between pretraining and training. The API and errors around this need some
improvement.
```bash
$ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
[--depth] [--embed-rows] [--loss_func] [--dropout] [--seed] [--n-iter] [--use-vectors]
[--n-save_every]
$ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir]
[--width] [--depth] [--embed-rows] [--loss_func] [--dropout] [--batch-size] [--max-length] [--min-length]
[--seed] [--n-iter] [--use-vectors] [--n-save_every] [--init-tok2vec] [--epoch-start]
```
| Argument | Type | Description |
@ -306,7 +306,8 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
| `--n-save-every`, `-se` | option | Save model every X batches. |
| `--init-tok2vec`, `-t2v` <Tag variant="new">2.1</Tag> | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental.|
| `--init-tok2vec`, `-t2v` <Tag variant="new">2.1</Tag> | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental.|
| `--epoch-start`, `-es` <Tag variant="new">2.1.5</Tag> | option | The epoch to start counting at. Only relevant when using `--init-tok2vec` and the given weight file has been renamed. Prevents unintended overwriting of existing weight files.|
| **CREATES** | weights | The pre-trained weights that can be used to initialize `spacy train`. |
### JSONL format for raw text {#pretrain-jsonl}