Matthew Honnibal 2020-10-11 19:37:14 +00:00
parent 605f0618c4
commit f0354d42bc


@@ -7,6 +7,7 @@ from wasabi import Printer
 import random
 import sys
 import shutil
+import itertools
 from .example import Example
 from ..schemas import ConfigSchemaTraining
@@ -196,9 +197,9 @@ def train_while_improving(
         if not (step % eval_frequency):
             if optimizer.averages:
                 with nlp.use_params(optimizer.averages):
-                    score, other_scores = evaluate()
+                    score, other_scores = evaluate(step)
             else:
-                score, other_scores = evaluate()
+                score, other_scores = evaluate(step)
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
@@ -247,8 +248,10 @@ def create_evaluation_callback(
 ) -> Callable[[], Tuple[float, Dict[str, float]]]:
     weights = {key: value for key, value in weights.items() if value is not None}

-    def evaluate() -> Tuple[float, Dict[str, float]]:
-        dev_examples = list(dev_corpus(nlp))
+    def evaluate(step) -> Tuple[float, Dict[str, float]]:
+        # Limit dev_examples by steps, so we don't spend longer on
+        # the estimation than we have training.
+        dev_examples = list(itertools.islice(dev_corpus(nlp), step))
         try:
             scores = nlp.evaluate(dev_examples)
         except KeyError as e:
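
For context, a minimal sketch of the pattern this change relies on: itertools.islice caps a (possibly large or lazily generated) corpus at `step` items without materializing the whole thing first, which is why the evaluation cost now scales with how far training has progressed. The `make_examples` generator and the `step` value below are hypothetical stand-ins, not part of the commit.

import itertools

def make_examples():
    # Hypothetical stand-in for dev_corpus(nlp): a generator that could
    # yield many (or unbounded) examples.
    i = 0
    while True:
        yield f"example-{i}"
        i += 1

step = 5
# Take at most `step` items from the generator, mirroring
# list(itertools.islice(dev_corpus(nlp), step)) in the diff above.
dev_examples = list(itertools.islice(make_examples(), step))
print(dev_examples)  # ['example-0', 'example-1', 'example-2', 'example-3', 'example-4']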