This commit is contained in:
Matthew Honnibal 2020-10-11 19:37:14 +00:00
parent 605f0618c4
commit f0354d42bc

View File

@ -7,6 +7,7 @@ from wasabi import Printer
import random
import sys
import shutil
import itertools
from .example import Example
from ..schemas import ConfigSchemaTraining
@ -196,9 +197,9 @@ def train_while_improving(
if not (step % eval_frequency):
if optimizer.averages:
with nlp.use_params(optimizer.averages):
score, other_scores = evaluate()
score, other_scores = evaluate(step)
else:
score, other_scores = evaluate()
score, other_scores = evaluate(step)
results.append((score, step))
is_best_checkpoint = score == max(results)[0]
else:
@ -247,8 +248,10 @@ def create_evaluation_callback(
) -> Callable[[], Tuple[float, Dict[str, float]]]:
weights = {key: value for key, value in weights.items() if value is not None}
def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = list(dev_corpus(nlp))
def evaluate(step) -> Tuple[float, Dict[str, float]]:
# Limit dev_examples by steps, so we don't spend longer on
# the estimation than we have training.
dev_examples = list(itertools.islice(dev_corpus(nlp), step))
try:
scores = nlp.evaluate(dev_examples)
except KeyError as e: