From 19dd45fe9753491476fc3639b24a6c8a43809cda Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 28 Oct 2022 12:27:23 +0200 Subject: [PATCH] Add warnings for pointless applications of find_threshold(). --- spacy/cli/find_threshold.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index 9f1b14475..b62da3b2b 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -7,6 +7,7 @@ from typing import Optional, Tuple, Any, Dict, List import numpy import wasabi.tables +from pipeline import TextCategorizer, MultiLabel_TextCategorizer from ..errors import Errors from ..training import Corpus from ._util import app, Arg, Opt, import_code, setup_gpu @@ -111,6 +112,12 @@ def find_threshold( if not hasattr(pipe, "scorer"): raise AttributeError(Errors.E1045) + if isinstance(pipe, TextCategorizer): + wasabi.msg.warn( + "The `textcat` component doesn't use a threshold as it's not applicable to the concept of " + "exclusive classes. All thresholds will yield the same results." + ) + if not silent: wasabi.msg.info( title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} " @@ -150,8 +157,9 @@ def find_threshold( scores: Dict[float, float] = {} config_keys_full = ["components", pipe_name, *config_keys] table_col_widths = (10, 10) + thresholds = numpy.linspace(0, 1, n_trials) print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths)) - for threshold in numpy.linspace(0, 1, n_trials): + for threshold in thresholds: # Reload pipeline with overrides specifying the new threshold. nlp = util.load_model( model, @@ -183,9 +191,23 @@ def find_threshold( ) best_threshold = max(scores.keys(), key=(lambda key: scores[key])) - if not silent: - print( - f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}." + + # If all scores are identical, emit warning. + if all([score == scores[thresholds[0]] for score in scores.values()]): + wasabi.msg.warn( + title="All scores are identical. Verify that all settings are correct.", + text="" + if ( + not isinstance(pipe, MultiLabel_TextCategorizer) + or scores_key in ("cats_macro_f", "cats_micro_f") + ) + else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.", ) + else: + if not silent: + print( + f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}." + ) + return best_threshold, scores[best_threshold], scores