mirror of https://github.com/explosion/spaCy.git
synced 2025-08-28 16:04:55 +03:00
Add warnings for pointless applications of find_threshold().
This commit is contained in:
parent 9d947a42ce
commit 19dd45fe97
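
In short, the change warns in two situations where a threshold sweep is pointless: when the tuned component is an exclusive-classes `textcat` (a threshold cannot change which single class wins), and when every trialled threshold produces the same score. The snippet below is a minimal, self-contained sketch of those two checks, not spaCy's API; `check_threshold_sweep` is a hypothetical helper name used only for illustration.

# Sketch of the two warning conditions this commit adds (hypothetical helper,
# not part of spaCy): warn for an exclusive-classes textcat, and warn when all
# trialled thresholds score identically.
from typing import Dict, List

def check_threshold_sweep(scores: Dict[float, float], exclusive_textcat: bool) -> List[str]:
    warnings: List[str] = []
    if exclusive_textcat:
        # Mirrors the isinstance(pipe, TextCategorizer) warning in the diff below.
        warnings.append(
            "The `textcat` component doesn't use a threshold; "
            "all thresholds will yield the same results."
        )
    thresholds = list(scores)
    if thresholds and all(scores[t] == scores[thresholds[0]] for t in thresholds):
        # Mirrors the "all scores are identical" warning in the diff below.
        warnings.append("All scores are identical. Verify that all settings are correct.")
    return warnings

# A flat sweep (same score at every threshold) triggers the second warning.
print(check_threshold_sweep({t / 10: 0.83 for t in range(11)}, exclusive_textcat=False))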

@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Any, Dict, List
 import numpy
 import wasabi.tables
 
+from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
 from ..errors import Errors
 from ..training import Corpus
 from ._util import app, Arg, Opt, import_code, setup_gpu

@@ -111,6 +112,12 @@ def find_threshold(
     if not hasattr(pipe, "scorer"):
         raise AttributeError(Errors.E1045)
 
+    if isinstance(pipe, TextCategorizer):
+        wasabi.msg.warn(
+            "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
+            "exclusive classes. All thresholds will yield the same results."
+        )
+
     if not silent:
         wasabi.msg.info(
             title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "

@@ -150,8 +157,9 @@ def find_threshold(
     scores: Dict[float, float] = {}
     config_keys_full = ["components", pipe_name, *config_keys]
     table_col_widths = (10, 10)
+    thresholds = numpy.linspace(0, 1, n_trials)
     print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
-    for threshold in numpy.linspace(0, 1, n_trials):
+    for threshold in thresholds:
         # Reload pipeline with overrides specifying the new threshold.
         nlp = util.load_model(
             model,

@@ -183,9 +191,23 @@ def find_threshold(
         )
 
     best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
-    if not silent:
-        print(
-            f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
+
+    # If all scores are identical, emit warning.
+    if all([score == scores[thresholds[0]] for score in scores.values()]):
+        wasabi.msg.warn(
+            title="All scores are identical. Verify that all settings are correct.",
+            text=""
+            if (
+                not isinstance(pipe, MultiLabel_TextCategorizer)
+                or scores_key in ("cats_macro_f", "cats_micro_f")
+            )
+            else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
         )
 
+    else:
+        if not silent:
+            print(
+                f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
+            )
+
     return best_threshold, scores[best_threshold], scores
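
For reference, the sweep these hunks adjust boils down to scoring the pipeline once per threshold in [0, 1] and keeping the argmax, as in the stand-alone sketch below; `evaluate_with_threshold` is a hypothetical stand-in for reloading the pipeline with the threshold override and evaluating it on the dev corpus, which the real function does via `util.load_model` and a `Corpus`.

# Stand-alone sketch of the sweep pattern in find_threshold(): score once per
# threshold, then return the best threshold, its score, and the full score table.
from typing import Callable, Dict, Tuple
import numpy

def sweep(
    evaluate_with_threshold: Callable[[float], float],  # stand-in scorer
    n_trials: int = 11,
) -> Tuple[float, float, Dict[float, float]]:
    thresholds = numpy.linspace(0, 1, n_trials)
    scores: Dict[float, float] = {}
    for threshold in thresholds:
        scores[threshold] = evaluate_with_threshold(threshold)
    best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
    return best_threshold, scores[best_threshold], scores

# Toy scorer that peaks near 0.4, just to exercise the sweep.
best, best_score, _ = sweep(lambda t: 1.0 - abs(t - 0.4))
print(f"Best threshold: {round(best, ndigits=4)} with value of {best_score}.")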