mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Print epochs
This commit is contained in:
parent
98c026195b
commit
9e3695de6b
|
@ -296,16 +296,19 @@ def train(
|
||||||
|
|
||||||
|
|
||||||
def create_train_batches(nlp, corpus, cfg):
|
def create_train_batches(nlp, corpus, cfg):
|
||||||
epochs_todo = cfg.get("max_epochs", 0)
|
max_epochs = cfg.get("max_epochs", 0)
|
||||||
|
train_examples = list(corpus.train_dataset(
|
||||||
|
nlp,
|
||||||
|
shuffle=True,
|
||||||
|
gold_preproc=cfg["gold_preproc"],
|
||||||
|
max_length=cfg["max_length"]
|
||||||
|
))
|
||||||
|
|
||||||
|
epoch = 0
|
||||||
while True:
|
while True:
|
||||||
train_examples = list(corpus.train_dataset(
|
|
||||||
nlp,
|
|
||||||
shuffle=True,
|
|
||||||
gold_preproc=cfg["gold_preproc"],
|
|
||||||
max_length=cfg["max_length"]
|
|
||||||
))
|
|
||||||
if len(train_examples) == 0:
|
if len(train_examples) == 0:
|
||||||
raise ValueError(Errors.E988)
|
raise ValueError(Errors.E988)
|
||||||
|
epoch += 1
|
||||||
batches = util.minibatch_by_words(
|
batches = util.minibatch_by_words(
|
||||||
train_examples,
|
train_examples,
|
||||||
size=cfg["batch_size"],
|
size=cfg["batch_size"],
|
||||||
|
@ -314,15 +317,12 @@ def create_train_batches(nlp, corpus, cfg):
|
||||||
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
# make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
|
||||||
try:
|
try:
|
||||||
first = next(batches)
|
first = next(batches)
|
||||||
yield first
|
yield epoch, first
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise ValueError(Errors.E986)
|
raise ValueError(Errors.E986)
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
yield batch
|
yield epoch, batch
|
||||||
epochs_todo -= 1
|
if max_epochs >= 1 and epoch >= max_epochs:
|
||||||
# We intentionally compare exactly to 0 here, so that max_epochs < 1
|
|
||||||
# will not break.
|
|
||||||
if epochs_todo == 0:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
@ -427,7 +427,7 @@ def train_while_improving(
|
||||||
(nlp.make_doc(rt["text"]) for rt in raw_text), size=8
|
(nlp.make_doc(rt["text"]) for rt in raw_text), size=8
|
||||||
)
|
)
|
||||||
|
|
||||||
for step, batch in enumerate(train_data):
|
for step, (epoch, batch) in enumerate(train_data):
|
||||||
dropout = next(dropouts)
|
dropout = next(dropouts)
|
||||||
with nlp.select_pipes(enable=to_enable):
|
with nlp.select_pipes(enable=to_enable):
|
||||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||||
|
@ -449,6 +449,7 @@ def train_while_improving(
|
||||||
score, other_scores = (None, None)
|
score, other_scores = (None, None)
|
||||||
is_best_checkpoint = None
|
is_best_checkpoint = None
|
||||||
info = {
|
info = {
|
||||||
|
"epoch": epoch,
|
||||||
"step": step,
|
"step": step,
|
||||||
"score": score,
|
"score": score,
|
||||||
"other_scores": other_scores,
|
"other_scores": other_scores,
|
||||||
|
@ -487,9 +488,9 @@ def setup_printer(training, nlp):
|
||||||
score_widths = [max(len(col), 6) for col in score_cols]
|
score_widths = [max(len(col), 6) for col in score_cols]
|
||||||
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
|
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
|
||||||
loss_widths = [max(len(col), 8) for col in loss_cols]
|
loss_widths = [max(len(col), 8) for col in loss_cols]
|
||||||
table_header = ["#"] + loss_cols + score_cols + ["Score"]
|
table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
|
||||||
table_header = [col.upper() for col in table_header]
|
table_header = [col.upper() for col in table_header]
|
||||||
table_widths = [6] + loss_widths + score_widths + [6]
|
table_widths = [3, 6] + loss_widths + score_widths + [6]
|
||||||
table_aligns = ["r" for _ in table_widths]
|
table_aligns = ["r" for _ in table_widths]
|
||||||
|
|
||||||
msg.row(table_header, widths=table_widths)
|
msg.row(table_header, widths=table_widths)
|
||||||
|
@ -521,7 +522,7 @@ def setup_printer(training, nlp):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
data = (
|
data = (
|
||||||
[info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
|
[info["epoch"], info["step"]] + losses + scores + ["{0:.2f}".format(float(info["score"]))]
|
||||||
)
|
)
|
||||||
msg.row(data, widths=table_widths, aligns=table_aligns)
|
msg.row(data, widths=table_widths, aligns=table_aligns)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user