diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 2bf4c17c1..0ea895597 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -29,6 +29,9 @@ from .. import util
     embed_rows=("Embedding rows", "option", "er", int),
     use_vectors=("Whether to use the static vectors as input features", "flag", "uv"),
     dropout=("Dropout", "option", "d", float),
+    batch_size=("Number of words per training batch", "option", "bs", int),
+    max_length=("Max words per example.", "option", "xw", int),
+    min_length=("Min words per example.", "option", "nw", int),
     seed=("Seed for random number generators", "option", "s", float),
     nr_iter=("Number of iterations to pretrain", "option", "i", int),
 )
@@ -42,6 +45,9 @@ def pretrain(
     use_vectors=False,
     dropout=0.2,
     nr_iter=1000,
+    batch_size=3000,
+    max_length=500,
+    min_length=5,
     seed=0,
 ):
     """
@@ -109,9 +115,14 @@ def pretrain(
     msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
     for epoch in range(nr_iter):
         for batch in util.minibatch_by_words(
-            ((text, None) for text in texts), size=3000
+            ((text, None) for text in texts), size=batch_size
         ):
-            docs = make_docs(nlp, [text for (text, _) in batch])
+            docs = make_docs(
+                nlp,
+                [text for (text, _) in batch],
+                max_length=max_length,
+                min_length=min_length,
+            )
             loss = make_update(model, docs, optimizer, drop=dropout)
             progress = tracker.update(epoch, loss, docs)
             if progress:
@@ -152,7 +163,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
     return float(loss)
 
 
-def make_docs(nlp, batch, min_length=1, max_length=500):
+def make_docs(nlp, batch, min_length, max_length):
     docs = []
     for record in batch:
         text = record["text"]
@@ -241,11 +252,23 @@ class ProgressTracker(object):
             status = (
                 epoch,
                 self.nr_word,
-                "%.8f" % self.loss,
-                "%.8f" % loss_per_word,
+                _smart_round(self.loss, width=10),
+                _smart_round(loss_per_word, width=6),
                 int(wps),
             )
             self.prev_loss = float(self.loss)
             return status
         else:
             return None
+
+
+def _smart_round(figure, width=10, max_decimal=4):
+    """Round large numbers as integers, smaller numbers as decimals."""
+    n_digits = len(str(int(figure)))
+    n_decimal = width - (n_digits + 1)
+    if n_decimal <= 1:
+        return str(int(figure))
+    else:
+        n_decimal = min(n_decimal, max_decimal)
+        format_str = "%." + str(n_decimal) + "f"
+        return format_str % figure
diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md
index 2d3c13e37..a9d1a21b5 100644
--- a/website/docs/api/cli.md
+++ b/website/docs/api/cli.md
@@ -296,6 +296,9 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
 | `--depth`, `-cd`       | option | Depth of CNN layers.                                       |
 | `--embed-rows`, `-er`  | option | Number of embedding rows.                                  |
 | `--dropout`, `-d`      | option | Dropout rate.                                              |
+| `--batch-size`, `-bs`  | option | Number of words per training batch.                        |
+| `--max-length`, `-xw`  | option | Maximum words per example. Longer examples are discarded.  |
+| `--min-length`, `-nw`  | option | Minimum words per example. Shorter examples are discarded. |
 | `--seed`, `-s`         | option | Seed for random number generators.                         |
 | `--n-iter`, `-i`       | option | Number of iterations to pretrain.                          |
 | `--use-vectors`, `-uv` | flag   | Whether to use the static vectors as input features.       |
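
As an illustration of the new options (the corpus path, vectors model and output directory below are placeholders, and the flag values are simply this patch's defaults), an invocation following the command form documented in cli.md might look like:

python -m spacy pretrain texts.jsonl en_vectors_web_lg ./pretrain-output --batch-size 3000 --max-length 500 --min-length 5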
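
A rough sketch of how the new _smart_round() helper behaves once the patch above is applied; the input values are made up for illustration, not taken from a real training run:

# Illustrative only: demonstrates the width-based formatting of the new helper.
from spacy.cli.pretrain import _smart_round  # private helper added by this patch

print(_smart_round(123456789.0, width=10))  # "123456789": the integer part fills the width, so no decimals
print(_smart_round(123.456789, width=6))    # "123.46": the width budget leaves room for 2 decimals
print(_smart_round(4.5678321, width=6))     # "4.5678": decimals are capped at max_decimal=4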