From 04982ccc4033ec15864bba659430a8408ca94774 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bj=C3=B6rn=20B=C3=B6ing?=
<33514570+BreakBB@users.noreply.github.com>
Date: Tue, 9 Jul 2019 21:48:30 +0200
Subject: [PATCH] =?UTF-8?q?Update=20pretrain=20to=20prevent=20unintended?=
=?UTF-8?q?=20overwriting=20of=20weight=20fil=E2=80=A6=20(#3902)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Update pretrain to prevent unintended overwriting of weight files for #3859
* Add '--epoch-start' to pretrain docs
* Add mising pretrain arguments to bash example
* Update doc tag for v2.1.5
---
spacy/cli/pretrain.py | 33 +++++++++++++++++++++++++++++++--
website/docs/api/cli.md | 9 +++++----
2 files changed, 36 insertions(+), 6 deletions(-)
diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 2fe5b247a..678f12be1 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -5,6 +5,7 @@ import plac
import random
import numpy
import time
+import re
from collections import Counter
from pathlib import Path
from thinc.v2v import Affine, Maxout
@@ -65,6 +66,13 @@ from .train import _load_pretrained_tok2vec
"t2v",
Path,
),
+ epoch_start=(
+ "The epoch to start counting at. Only relevant when using '--init-tok2vec' and the given weight file has been "
+ "renamed. Prevents unintended overwriting of existing weight files.",
+ "option",
+ "es",
+ int
+ ),
)
def pretrain(
texts_loc,
@@ -83,6 +91,7 @@ def pretrain(
seed=0,
n_save_every=None,
init_tok2vec=None,
+ epoch_start=None,
):
"""
Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
@@ -151,9 +160,29 @@ def pretrain(
if init_tok2vec is not None:
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
msg.text("Loaded pretrained tok2vec for: {}".format(components))
+ # Parse the epoch number from the given weight file
+ model_name = re.search(r"model\d+\.bin", str(init_tok2vec))
+ if model_name:
+ # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
+ epoch_start = int(model_name.group(0)[5:][:-4]) + 1
+ else:
+ if not epoch_start:
+ msg.fail(
+ "You have to use the '--epoch-start' argument when using a renamed weight file for "
+ "'--init-tok2vec'", exits=True
+ )
+ elif epoch_start < 0:
+ msg.fail(
+ "The argument '--epoch-start' has to be greater or equal to 0. '%d' is invalid" % epoch_start,
+ exits=True
+ )
+ else:
+ # Without '--init-tok2vec' the '--epoch-start' argument is ignored
+ epoch_start = 0
+
optimizer = create_default_optimizer(model.ops)
tracker = ProgressTracker(frequency=10000)
- msg.divider("Pre-training tok2vec layer")
+ msg.divider("Pre-training tok2vec layer - starting at epoch %d" % epoch_start)
row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")}
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
@@ -174,7 +203,7 @@ def pretrain(
file_.write(srsly.json_dumps(log) + "\n")
skip_counter = 0
- for epoch in range(n_iter):
+ for epoch in range(epoch_start, n_iter + epoch_start):
for batch_id, batch in enumerate(
util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
):
diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md
index a69e62219..7af134e40 100644
--- a/website/docs/api/cli.md
+++ b/website/docs/api/cli.md
@@ -284,9 +284,9 @@ same between pretraining and training. The API and errors around this need some
improvement.
```bash
-$ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
-[--depth] [--embed-rows] [--loss_func] [--dropout] [--seed] [--n-iter] [--use-vectors]
-[--n-save_every]
+$ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir]
+[--width] [--depth] [--embed-rows] [--loss_func] [--dropout] [--batch-size] [--max-length] [--min-length]
+[--seed] [--n-iter] [--use-vectors] [--n-save_every] [--init-tok2vec] [--epoch-start]
```
| Argument | Type | Description |
@@ -306,7 +306,8 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
| `--n-save-every`, `-se` | option | Save model every X batches. |
-| `--init-tok2vec`, `-t2v` 2.1 | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental.|
+| `--init-tok2vec`, `-t2v` 2.1 | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental.|
+| `--epoch-start`, `-es` 2.1.5 | option | The epoch to start counting at. Only relevant when using `--init-tok2vec` and the given weight file has been renamed. Prevents unintended overwriting of existing weight files.|
| **CREATES** | weights | The pre-trained weights that can be used to initialize `spacy train`. |
### JSONL format for raw text {#pretrain-jsonl}