mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-28 16:04:55 +03:00
Add warnings for pointless applications of find_threshold().
This commit is contained in:
parent
9d947a42ce
commit
19dd45fe97
|
@ -7,6 +7,7 @@ from typing import Optional, Tuple, Any, Dict, List
|
|||
import numpy
|
||||
import wasabi.tables
|
||||
|
||||
from pipeline import TextCategorizer, MultiLabel_TextCategorizer
|
||||
from ..errors import Errors
|
||||
from ..training import Corpus
|
||||
from ._util import app, Arg, Opt, import_code, setup_gpu
|
||||
|
@ -111,6 +112,12 @@ def find_threshold(
|
|||
if not hasattr(pipe, "scorer"):
|
||||
raise AttributeError(Errors.E1045)
|
||||
|
||||
if isinstance(pipe, TextCategorizer):
|
||||
wasabi.msg.warn(
|
||||
"The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
|
||||
"exclusive classes. All thresholds will yield the same results."
|
||||
)
|
||||
|
||||
if not silent:
|
||||
wasabi.msg.info(
|
||||
title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
|
||||
|
@ -150,8 +157,9 @@ def find_threshold(
|
|||
scores: Dict[float, float] = {}
|
||||
config_keys_full = ["components", pipe_name, *config_keys]
|
||||
table_col_widths = (10, 10)
|
||||
thresholds = numpy.linspace(0, 1, n_trials)
|
||||
print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
|
||||
for threshold in numpy.linspace(0, 1, n_trials):
|
||||
for threshold in thresholds:
|
||||
# Reload pipeline with overrides specifying the new threshold.
|
||||
nlp = util.load_model(
|
||||
model,
|
||||
|
@ -183,6 +191,20 @@ def find_threshold(
|
|||
)
|
||||
|
||||
best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
|
||||
|
||||
# If all scores are identical, emit warning.
|
||||
if all([score == scores[thresholds[0]] for score in scores.values()]):
|
||||
wasabi.msg.warn(
|
||||
title="All scores are identical. Verify that all settings are correct.",
|
||||
text=""
|
||||
if (
|
||||
not isinstance(pipe, MultiLabel_TextCategorizer)
|
||||
or scores_key in ("cats_macro_f", "cats_micro_f")
|
||||
)
|
||||
else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
|
||||
)
|
||||
|
||||
else:
|
||||
if not silent:
|
||||
print(
|
||||
f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
|
||||
|
|
Loading…
Reference in New Issue
Block a user