diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index d199236b9..46e73cd88 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -296,16 +296,19 @@ def train(


 def create_train_batches(nlp, corpus, cfg):
-    epochs_todo = cfg.get("max_epochs", 0)
+    max_epochs = cfg.get("max_epochs", 0)
+    train_examples = list(corpus.train_dataset(
+        nlp,
+        shuffle=True,
+        gold_preproc=cfg["gold_preproc"],
+        max_length=cfg["max_length"]
+    ))
+
+    epoch = 0
     while True:
-        train_examples = list(corpus.train_dataset(
-            nlp,
-            shuffle=True,
-            gold_preproc=cfg["gold_preproc"],
-            max_length=cfg["max_length"]
-        ))
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
+        epoch += 1
         batches = util.minibatch_by_words(
             train_examples,
             size=cfg["batch_size"],
@@ -314,15 +317,12 @@ def create_train_batches(nlp, corpus, cfg):
         # make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
         try:
             first = next(batches)
-            yield first
+            yield epoch, first
         except StopIteration:
             raise ValueError(Errors.E986)
         for batch in batches:
-            yield batch
-        epochs_todo -= 1
-        # We intentionally compare exactly to 0 here, so that max_epochs < 1
-        # will not break.
-        if epochs_todo == 0:
+            yield epoch, batch
+        if max_epochs >= 1 and epoch >= max_epochs:
             break


@@ -427,7 +427,7 @@ def train_while_improving(
             (nlp.make_doc(rt["text"]) for rt in raw_text), size=8
         )

-    for step, batch in enumerate(train_data):
+    for step, (epoch, batch) in enumerate(train_data):
         dropout = next(dropouts)
         with nlp.select_pipes(enable=to_enable):
             for subbatch in subdivide_batch(batch, accumulate_gradient):
@@ -449,6 +449,7 @@ def train_while_improving(
             score, other_scores = (None, None)
             is_best_checkpoint = None
         info = {
+            "epoch": epoch,
             "step": step,
             "score": score,
             "other_scores": other_scores,
@@ -487,9 +488,9 @@ def setup_printer(training, nlp):
     score_widths = [max(len(col), 6) for col in score_cols]
     loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
     loss_widths = [max(len(col), 8) for col in loss_cols]
-    table_header = ["#"] + loss_cols + score_cols + ["Score"]
+    table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
     table_header = [col.upper() for col in table_header]
-    table_widths = [6] + loss_widths + score_widths + [6]
+    table_widths = [3, 6] + loss_widths + score_widths + [6]
     table_aligns = ["r" for _ in table_widths]
     msg.row(table_header, widths=table_widths)
@@ -521,7 +522,7 @@ def setup_printer(training, nlp):
             )
         )
         data = (
-            [info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
+            [info["epoch"], info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
        )
         msg.row(data, widths=table_widths, aligns=table_aligns)
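
For reference, the control flow this patch introduces in create_train_batches can be exercised on its own with a minimal sketch. This is not part of the diff: epoch_batches and make_batches below are hypothetical stand-ins for create_train_batches and util.minibatch_by_words, with the corpus loading left out.

# Minimal standalone sketch (not part of the patch) of the epoch-aware
# generator contract: the examples are loaded once, the epoch counter is
# incremented per pass, and each batch is yielded as an (epoch, batch) tuple.
def epoch_batches(train_examples, make_batches, max_epochs=0):
    if len(train_examples) == 0:
        raise ValueError("no training examples")
    epoch = 0
    while True:
        epoch += 1
        for batch in make_batches(train_examples):
            # Consumers (cf. train_while_improving) unpack (epoch, batch).
            yield epoch, batch
        # max_epochs < 1 means "keep looping until stopped externally";
        # otherwise stop after the final full pass over the data.
        if max_epochs >= 1 and epoch >= max_epochs:
            break


if __name__ == "__main__":
    # Three examples, batches of two, two epochs: yields epochs 1, 1, 2, 2.
    examples = [1, 2, 3]
    chunk = lambda xs: (xs[i:i + 2] for i in range(0, len(xs), 2))
    for epoch, batch in epoch_batches(examples, chunk, max_epochs=2):
        print(epoch, batch)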