mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 21:24:11 +03:00
Expose batch size and length caps on CLI for pretrain (#3417)
Add and document CLI options for batch size, max doc length, min doc length for `spacy pretrain`. Also improve CLI output. Closes #3216 ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
This commit is contained in:
parent
58d562d9b0
commit
62afa64a8d
|
@ -29,6 +29,9 @@ from .. import util
|
|||
embed_rows=("Embedding rows", "option", "er", int),
|
||||
use_vectors=("Whether to use the static vectors as input features", "flag", "uv"),
|
||||
dropout=("Dropout", "option", "d", float),
|
||||
batch_size=("Number of words per training batch", "option", "bs", int),
|
||||
max_length=("Max words per example.", "option", "xw", int),
|
||||
min_length=("Min words per example.", "option", "nw", int),
|
||||
seed=("Seed for random number generators", "option", "s", float),
|
||||
nr_iter=("Number of iterations to pretrain", "option", "i", int),
|
||||
)
|
||||
|
@ -42,6 +45,9 @@ def pretrain(
|
|||
use_vectors=False,
|
||||
dropout=0.2,
|
||||
nr_iter=1000,
|
||||
batch_size=3000,
|
||||
max_length=500,
|
||||
min_length=5,
|
||||
seed=0,
|
||||
):
|
||||
"""
|
||||
|
@ -109,9 +115,14 @@ def pretrain(
|
|||
msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
|
||||
for epoch in range(nr_iter):
|
||||
for batch in util.minibatch_by_words(
|
||||
((text, None) for text in texts), size=3000
|
||||
((text, None) for text in texts), size=batch_size
|
||||
):
|
||||
docs = make_docs(nlp, [text for (text, _) in batch])
|
||||
docs = make_docs(
|
||||
nlp,
|
||||
[text for (text, _) in batch],
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
)
|
||||
loss = make_update(model, docs, optimizer, drop=dropout)
|
||||
progress = tracker.update(epoch, loss, docs)
|
||||
if progress:
|
||||
|
@ -152,7 +163,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
|||
return float(loss)
|
||||
|
||||
|
||||
def make_docs(nlp, batch, min_length=1, max_length=500):
|
||||
def make_docs(nlp, batch, min_length, max_length):
|
||||
docs = []
|
||||
for record in batch:
|
||||
text = record["text"]
|
||||
|
@ -241,11 +252,23 @@ class ProgressTracker(object):
|
|||
status = (
|
||||
epoch,
|
||||
self.nr_word,
|
||||
"%.8f" % self.loss,
|
||||
"%.8f" % loss_per_word,
|
||||
_smart_round(self.loss, width=10),
|
||||
_smart_round(loss_per_word, width=6),
|
||||
int(wps),
|
||||
)
|
||||
self.prev_loss = float(self.loss)
|
||||
return status
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _smart_round(figure, width=10, max_decimal=4):
|
||||
"""Round large numbers as integers, smaller numbers as decimals."""
|
||||
n_digits = len(str(int(figure)))
|
||||
n_decimal = width - (n_digits + 1)
|
||||
if n_decimal <= 1:
|
||||
return str(int(figure))
|
||||
else:
|
||||
n_decimal = min(n_decimal, max_decimal)
|
||||
format_str = "%." + str(n_decimal) + "f"
|
||||
return format_str % figure
|
||||
|
|
|
@ -296,6 +296,9 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
|
|||
| `--depth`, `-cd` | option | Depth of CNN layers. |
|
||||
| `--embed-rows`, `-er` | option | Number of embedding rows. |
|
||||
| `--dropout`, `-d` | option | Dropout rate. |
|
||||
| `--batch-size`, `-bs` | option | Number of words per training batch. |
|
||||
| `--max-length`, `-xw` | option | Maximum words per example. Longer examples are discarded. |
|
||||
| `--min-length`, `-nw` | option | Minimum words per example. Shorter examples are discarded. |
|
||||
| `--seed`, `-s` | option | Seed for random number generators. |
|
||||
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
|
||||
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
|
||||
|
|
Loading…
Reference in New Issue
Block a user