mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Add words and seconds to train info
This commit is contained in:
parent
b470062153
commit
ba5f4c9b32
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
||||
from timeit import default_timer as timer
|
||||
import srsly
|
||||
import tqdm
|
||||
from pathlib import Path
|
||||
|
@ -286,9 +287,12 @@ def train_while_improving(
|
|||
]
|
||||
raw_batches = util.minibatch(raw_examples, size=8)
|
||||
|
||||
words_seen = 0
|
||||
start_time = timer()
|
||||
for step, (epoch, batch) in enumerate(train_data):
|
||||
dropout = next(dropouts)
|
||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||
|
||||
nlp.update(
|
||||
subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
|
||||
)
|
||||
|
@ -317,6 +321,7 @@ def train_while_improving(
|
|||
else:
|
||||
score, other_scores = (None, None)
|
||||
is_best_checkpoint = None
|
||||
words_seen += sum(len(eg) for eg in batch)
|
||||
info = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
|
@ -324,6 +329,8 @@ def train_while_improving(
|
|||
"other_scores": other_scores,
|
||||
"losses": losses,
|
||||
"checkpoints": results,
|
||||
"seconds": int(timer() - start_time),
|
||||
"words": words_seen,
|
||||
}
|
||||
yield batch, info, is_best_checkpoint
|
||||
if is_best_checkpoint is not None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user