mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-05 13:43:24 +03:00
Expose batch size and length caps on CLI for pretrain (#3417)
Add and document CLI options for batch size, max doc length, min doc length for `spacy pretrain`. Also improve CLI output. Closes #3216 ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
This commit is contained in:
parent
58d562d9b0
commit
62afa64a8d
|
@ -29,6 +29,9 @@ from .. import util
|
||||||
embed_rows=("Embedding rows", "option", "er", int),
|
embed_rows=("Embedding rows", "option", "er", int),
|
||||||
use_vectors=("Whether to use the static vectors as input features", "flag", "uv"),
|
use_vectors=("Whether to use the static vectors as input features", "flag", "uv"),
|
||||||
dropout=("Dropout", "option", "d", float),
|
dropout=("Dropout", "option", "d", float),
|
||||||
|
batch_size=("Number of words per training batch", "option", "bs", int),
|
||||||
|
max_length=("Max words per example.", "option", "xw", int),
|
||||||
|
min_length=("Min words per example.", "option", "nw", int),
|
||||||
seed=("Seed for random number generators", "option", "s", float),
|
seed=("Seed for random number generators", "option", "s", float),
|
||||||
nr_iter=("Number of iterations to pretrain", "option", "i", int),
|
nr_iter=("Number of iterations to pretrain", "option", "i", int),
|
||||||
)
|
)
|
||||||
|
@ -42,6 +45,9 @@ def pretrain(
|
||||||
use_vectors=False,
|
use_vectors=False,
|
||||||
dropout=0.2,
|
dropout=0.2,
|
||||||
nr_iter=1000,
|
nr_iter=1000,
|
||||||
|
batch_size=3000,
|
||||||
|
max_length=500,
|
||||||
|
min_length=5,
|
||||||
seed=0,
|
seed=0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -109,9 +115,14 @@ def pretrain(
|
||||||
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
||||||
for epoch in range(nr_iter):
|
for epoch in range(nr_iter):
|
||||||
for batch in util.minibatch_by_words(
|
for batch in util.minibatch_by_words(
|
||||||
((text, None) for text in texts), size=3000
|
((text, None) for text in texts), size=batch_size
|
||||||
):
|
):
|
||||||
docs = make_docs(nlp, [text for (text, _) in batch])
|
docs = make_docs(
|
||||||
|
nlp,
|
||||||
|
[text for (text, _) in batch],
|
||||||
|
max_length=max_length,
|
||||||
|
min_length=min_length,
|
||||||
|
)
|
||||||
loss = make_update(model, docs, optimizer, drop=dropout)
|
loss = make_update(model, docs, optimizer, drop=dropout)
|
||||||
progress = tracker.update(epoch, loss, docs)
|
progress = tracker.update(epoch, loss, docs)
|
||||||
if progress:
|
if progress:
|
||||||
|
@ -152,7 +163,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
||||||
return float(loss)
|
return float(loss)
|
||||||
|
|
||||||
|
|
||||||
def make_docs(nlp, batch, min_length=1, max_length=500):
|
def make_docs(nlp, batch, min_length, max_length):
|
||||||
docs = []
|
docs = []
|
||||||
for record in batch:
|
for record in batch:
|
||||||
text = record["text"]
|
text = record["text"]
|
||||||
|
@ -241,11 +252,23 @@ class ProgressTracker(object):
|
||||||
status = (
|
status = (
|
||||||
epoch,
|
epoch,
|
||||||
self.nr_word,
|
self.nr_word,
|
||||||
"%.8f" % self.loss,
|
_smart_round(self.loss, width=10),
|
||||||
"%.8f" % loss_per_word,
|
_smart_round(loss_per_word, width=6),
|
||||||
int(wps),
|
int(wps),
|
||||||
)
|
)
|
||||||
self.prev_loss = float(self.loss)
|
self.prev_loss = float(self.loss)
|
||||||
return status
|
return status
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _smart_round(figure, width=10, max_decimal=4):
|
||||||
|
"""Round large numbers as integers, smaller numbers as decimals."""
|
||||||
|
n_digits = len(str(int(figure)))
|
||||||
|
n_decimal = width - (n_digits + 1)
|
||||||
|
if n_decimal <= 1:
|
||||||
|
return str(int(figure))
|
||||||
|
else:
|
||||||
|
n_decimal = min(n_decimal, max_decimal)
|
||||||
|
format_str = "%." + str(n_decimal) + "f"
|
||||||
|
return format_str % figure
|
||||||
|
|
|
@ -296,6 +296,9 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
|
||||||
| `--depth`, `-cd` | option | Depth of CNN layers. |
|
| `--depth`, `-cd` | option | Depth of CNN layers. |
|
||||||
| `--embed-rows`, `-er` | option | Number of embedding rows. |
|
| `--embed-rows`, `-er` | option | Number of embedding rows. |
|
||||||
| `--dropout`, `-d` | option | Dropout rate. |
|
| `--dropout`, `-d` | option | Dropout rate. |
|
||||||
|
| `--batch-size`, `-bs` | option | Number of words per training batch. |
|
||||||
|
| `--max-length`, `-xw` | option | Maximum words per example. Longer examples are discarded. |
|
||||||
|
| `--min-length`, `-nw` | option | Minimum words per example. Shorter examples are discarded. |
|
||||||
| `--seed`, `-s` | option | Seed for random number generators. |
|
| `--seed`, `-s` | option | Seed for random number generators. |
|
||||||
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
|
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
|
||||||
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
|
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user