From f69fea8b252ac5f28c4daac40046df507ab6f07f Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 24 Sep 2020 11:29:07 +0200 Subject: [PATCH] Improve error handling around non-number scores --- spacy/cli/train.py | 7 ++++++- spacy/errors.py | 4 ++++ spacy/training/loggers.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 3485a4ff2..eabc82be0 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -214,7 +214,12 @@ def create_evaluation_callback( def evaluate() -> Tuple[float, Dict[str, float]]: dev_examples = list(dev_corpus(nlp)) scores = nlp.evaluate(dev_examples) - # Calculate a weighted sum based on score_weights for the main score + # Calculate a weighted sum based on score_weights for the main score. + # We can only consider scores that are ints/floats, not dicts like + # entity scores per type etc. + for key, value in scores.items(): + if key in weights and not isinstance(value, (int, float)): + raise ValueError(Errors.E915.format(name=key, score_type=type(value))) try: weighted_score = sum( scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights diff --git a/spacy/errors.py b/spacy/errors.py index ee2091225..dce5cf51c 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -480,6 +480,10 @@ class Errors: E201 = ("Span index out of range.") # TODO: fix numbering after merging develop into master + E915 = ("Can't use score '{name}' to calculate final weighted score. Expected " + "float or int but got: {score_type}. To exclude the score from the " + "final score, set its weight to null in the [training.score_weights] " + "section of your training config.") E916 = ("Can't log score for '{name}' in table: not a valid score ({score_type})") E917 = ("Received invalid value {value} for 'state_type' in " "TransitionBasedParser: only 'parser' or 'ner' are valid options.") diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py index d35b5a4bd..0f054d433 100644 --- a/spacy/training/loggers.py +++ b/spacy/training/loggers.py @@ -49,7 +49,7 @@ def console_logger(): scores.append("{0:.2f}".format(score)) except TypeError: err = Errors.E916.format(name=col, score_type=type(score)) - raise TypeError(err) from None + raise ValueError(err) from None data = ( [info["epoch"], info["step"]] + losses