mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Improve error handling around non-number scores
This commit is contained in:
parent
4eb39b5c43
commit
f69fea8b25
|
@ -214,7 +214,12 @@ def create_evaluation_callback(
|
||||||
def evaluate() -> Tuple[float, Dict[str, float]]:
|
def evaluate() -> Tuple[float, Dict[str, float]]:
|
||||||
dev_examples = list(dev_corpus(nlp))
|
dev_examples = list(dev_corpus(nlp))
|
||||||
scores = nlp.evaluate(dev_examples)
|
scores = nlp.evaluate(dev_examples)
|
||||||
# Calculate a weighted sum based on score_weights for the main score
|
# Calculate a weighted sum based on score_weights for the main score.
|
||||||
|
# We can only consider scores that are ints/floats, not dicts like
|
||||||
|
# entity scores per type etc.
|
||||||
|
for key, value in scores.items():
|
||||||
|
if key in weights and not isinstance(value, (int, float)):
|
||||||
|
raise ValueError(Errors.E915.format(name=key, score_type=type(value)))
|
||||||
try:
|
try:
|
||||||
weighted_score = sum(
|
weighted_score = sum(
|
||||||
scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights
|
scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights
|
||||||
|
|
|
@ -480,6 +480,10 @@ class Errors:
|
||||||
E201 = ("Span index out of range.")
|
E201 = ("Span index out of range.")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
|
||||||
|
"float or int but got: {score_type}. To exclude the score from the "
|
||||||
|
"final score, set its weight to null in the [training.score_weights] "
|
||||||
|
"section of your training config.")
|
||||||
E916 = ("Can't log score for '{name}' in table: not a valid score ({score_type})")
|
E916 = ("Can't log score for '{name}' in table: not a valid score ({score_type})")
|
||||||
E917 = ("Received invalid value {value} for 'state_type' in "
|
E917 = ("Received invalid value {value} for 'state_type' in "
|
||||||
"TransitionBasedParser: only 'parser' or 'ner' are valid options.")
|
"TransitionBasedParser: only 'parser' or 'ner' are valid options.")
|
||||||
|
|
|
@ -49,7 +49,7 @@ def console_logger():
|
||||||
scores.append("{0:.2f}".format(score))
|
scores.append("{0:.2f}".format(score))
|
||||||
except TypeError:
|
except TypeError:
|
||||||
err = Errors.E916.format(name=col, score_type=type(score))
|
err = Errors.E916.format(name=col, score_type=type(score))
|
||||||
raise TypeError(err) from None
|
raise ValueError(err) from None
|
||||||
data = (
|
data = (
|
||||||
[info["epoch"], info["step"]]
|
[info["epoch"], info["step"]]
|
||||||
+ losses
|
+ losses
|
||||||
|
|
Loading…
Reference in New Issue
Block a user