mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-10 16:40:34 +03:00
Print epochs
This commit is contained in:
parent
98c026195b
commit
9e3695de6b
|
@ -296,16 +296,19 @@ def train(
|
|||
|
||||
|
||||
def create_train_batches(nlp, corpus, cfg):
|
||||
epochs_todo = cfg.get("max_epochs", 0)
|
||||
max_epochs = cfg.get("max_epochs", 0)
|
||||
train_examples = list(corpus.train_dataset(
|
||||
nlp,
|
||||
shuffle=True,
|
||||
gold_preproc=cfg["gold_preproc"],
|
||||
max_length=cfg["max_length"]
|
||||
))
|
||||
|
||||
epoch = 0
|
||||
while True:
|
||||
train_examples = list(corpus.train_dataset(
|
||||
nlp,
|
||||
shuffle=True,
|
||||
gold_preproc=cfg["gold_preproc"],
|
||||
max_length=cfg["max_length"]
|
||||
))
|
||||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
epoch += 1
|
||||
batches = util.minibatch_by_words(
|
||||
train_examples,
|
||||
size=cfg["batch_size"],
|
||||
|
@ -314,15 +317,12 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
||||
try:
|
||||
first = next(batches)
|
||||
yield first
|
||||
yield epoch, first
|
||||
except StopIteration:
|
||||
raise ValueError(Errors.E986)
|
||||
for batch in batches:
|
||||
yield batch
|
||||
epochs_todo -= 1
|
||||
# We intentionally compare exactly to 0 here, so that max_epochs < 1
|
||||
# will not break.
|
||||
if epochs_todo == 0:
|
||||
yield epoch, batch
|
||||
if max_epochs >= 1 and epoch >= max_epochs:
|
||||
break
|
||||
|
||||
|
||||
|
@ -427,7 +427,7 @@ def train_while_improving(
|
|||
(nlp.make_doc(rt["text"]) for rt in raw_text), size=8
|
||||
)
|
||||
|
||||
for step, batch in enumerate(train_data):
|
||||
for step, (epoch, batch) in enumerate(train_data):
|
||||
dropout = next(dropouts)
|
||||
with nlp.select_pipes(enable=to_enable):
|
||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||
|
@ -449,6 +449,7 @@ def train_while_improving(
|
|||
score, other_scores = (None, None)
|
||||
is_best_checkpoint = None
|
||||
info = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"score": score,
|
||||
"other_scores": other_scores,
|
||||
|
@ -487,9 +488,9 @@ def setup_printer(training, nlp):
|
|||
score_widths = [max(len(col), 6) for col in score_cols]
|
||||
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
|
||||
loss_widths = [max(len(col), 8) for col in loss_cols]
|
||||
table_header = ["#"] + loss_cols + score_cols + ["Score"]
|
||||
table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
|
||||
table_header = [col.upper() for col in table_header]
|
||||
table_widths = [6] + loss_widths + score_widths + [6]
|
||||
table_widths = [3, 6] + loss_widths + score_widths + [6]
|
||||
table_aligns = ["r" for _ in table_widths]
|
||||
|
||||
msg.row(table_header, widths=table_widths)
|
||||
|
@ -521,7 +522,7 @@ def setup_printer(training, nlp):
|
|||
)
|
||||
)
|
||||
data = (
|
||||
[info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
|
||||
[info["epoch"], info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
|
||||
)
|
||||
msg.row(data, widths=table_widths, aligns=table_aligns)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user