Improve error handling around non-number scores

This commit is contained in:
Ines Montani 2020-09-24 11:29:07 +02:00
parent 4eb39b5c43
commit f69fea8b25
3 changed files with 11 additions and 2 deletions

View File

@@ -214,7 +214,12 @@ def create_evaluation_callback(
def evaluate() -> Tuple[float, Dict[str, float]]: def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = list(dev_corpus(nlp)) dev_examples = list(dev_corpus(nlp))
scores = nlp.evaluate(dev_examples) scores = nlp.evaluate(dev_examples)
# Calculate a weighted sum based on score_weights for the main score # Calculate a weighted sum based on score_weights for the main score.
# We can only consider scores that are ints/floats, not dicts like
# entity scores per type etc.
for key, value in scores.items():
if key in weights and not isinstance(value, (int, float)):
raise ValueError(Errors.E915.format(name=key, score_type=type(value)))
try: try:
weighted_score = sum( weighted_score = sum(
scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights

View File

@@ -480,6 +480,10 @@ class Errors:
E201 = ("Span index out of range.") E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
"float or int but got: {score_type}. To exclude the score from the "
"final score, set its weight to null in the [training.score_weights] "
"section of your training config.")
E916 = ("Can't log score for '{name}' in table: not a valid score ({score_type})") E916 = ("Can't log score for '{name}' in table: not a valid score ({score_type})")
E917 = ("Received invalid value {value} for 'state_type' in " E917 = ("Received invalid value {value} for 'state_type' in "
"TransitionBasedParser: only 'parser' or 'ner' are valid options.") "TransitionBasedParser: only 'parser' or 'ner' are valid options.")

View File

@@ -49,7 +49,7 @@ def console_logger():
scores.append("{0:.2f}".format(score)) scores.append("{0:.2f}".format(score))
except TypeError: except TypeError:
err = Errors.E916.format(name=col, score_type=type(score)) err = Errors.E916.format(name=col, score_type=type(score))
raise TypeError(err) from None raise ValueError(err) from None
data = ( data = (
[info["epoch"], info["step"]] [info["epoch"], info["step"]]
+ losses + losses