Print epochs

This commit is contained in:
Matthew Honnibal 2020-06-25 21:18:08 +02:00
parent 98c026195b
commit 9e3695de6b

View File

@ -296,16 +296,19 @@ def train(
def create_train_batches(nlp, corpus, cfg):
epochs_todo = cfg.get("max_epochs", 0)
max_epochs = cfg.get("max_epochs", 0)
train_examples = list(corpus.train_dataset(
nlp,
shuffle=True,
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"]
))
epoch = 0
while True:
train_examples = list(corpus.train_dataset(
nlp,
shuffle=True,
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"]
))
if len(train_examples) == 0:
raise ValueError(Errors.E988)
epoch += 1
batches = util.minibatch_by_words(
train_examples,
size=cfg["batch_size"],
@ -314,15 +317,12 @@ def create_train_batches(nlp, corpus, cfg):
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
try:
first = next(batches)
yield first
yield epoch, first
except StopIteration:
raise ValueError(Errors.E986)
for batch in batches:
yield batch
epochs_todo -= 1
# We intentionally compare exactly to 0 here, so that max_epochs < 1
# will not break.
if epochs_todo == 0:
yield epoch, batch
if max_epochs >= 1 and epoch >= max_epochs:
break
@ -427,7 +427,7 @@ def train_while_improving(
(nlp.make_doc(rt["text"]) for rt in raw_text), size=8
)
for step, batch in enumerate(train_data):
for step, (epoch, batch) in enumerate(train_data):
dropout = next(dropouts)
with nlp.select_pipes(enable=to_enable):
for subbatch in subdivide_batch(batch, accumulate_gradient):
@ -449,6 +449,7 @@ def train_while_improving(
score, other_scores = (None, None)
is_best_checkpoint = None
info = {
"epoch": epoch,
"step": step,
"score": score,
"other_scores": other_scores,
@ -487,9 +488,9 @@ def setup_printer(training, nlp):
score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
loss_widths = [max(len(col), 8) for col in loss_cols]
table_header = ["#"] + loss_cols + score_cols + ["Score"]
table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
table_header = [col.upper() for col in table_header]
table_widths = [6] + loss_widths + score_widths + [6]
table_widths = [3, 6] + loss_widths + score_widths + [6]
table_aligns = ["r" for _ in table_widths]
msg.row(table_header, widths=table_widths)
@ -521,7 +522,7 @@ def setup_printer(training, nlp):
)
)
data = (
[info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
[info["epoch"], info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
)
msg.row(data, widths=table_widths, aligns=table_aligns)