Add warnings for pointless applications of find_threshold().

This commit is contained in:
Raphael Mitsch 2022-10-28 12:27:23 +02:00
parent 9d947a42ce
commit 19dd45fe97

View File

@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Any, Dict, List
 import numpy
 import wasabi.tables
+from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
 from ..errors import Errors
 from ..training import Corpus
 from ._util import app, Arg, Opt, import_code, setup_gpu
@@ -111,6 +112,12 @@ def find_threshold(
     if not hasattr(pipe, "scorer"):
         raise AttributeError(Errors.E1045)
+    if isinstance(pipe, TextCategorizer):
+        wasabi.msg.warn(
+            "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
+            "exclusive classes. All thresholds will yield the same results."
+        )
     if not silent:
         wasabi.msg.info(
             title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
@@ -150,8 +157,9 @@ def find_threshold(
     scores: Dict[float, float] = {}
     config_keys_full = ["components", pipe_name, *config_keys]
     table_col_widths = (10, 10)
+    thresholds = numpy.linspace(0, 1, n_trials)
     print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
-    for threshold in numpy.linspace(0, 1, n_trials):
+    for threshold in thresholds:
         # Reload pipeline with overrides specifying the new threshold.
         nlp = util.load_model(
             model,
@@ -183,9 +191,23 @@ def find_threshold(
         )
     best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
-    if not silent:
-        print(
-            f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
-        )
+
+    # If all scores are identical, emit warning.
+    if all([score == scores[thresholds[0]] for score in scores.values()]):
+        wasabi.msg.warn(
+            title="All scores are identical. Verify that all settings are correct.",
+            text=""
+            if (
+                not isinstance(pipe, MultiLabel_TextCategorizer)
+                or scores_key in ("cats_macro_f", "cats_micro_f")
+            )
+            else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
+        )
+    else:
+        if not silent:
+            print(
+                f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
+            )
+
     return best_threshold, scores[best_threshold], scores