From ebbbc3a61148fd5fe647b8e328b336afd588fcc8 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Mon, 24 Oct 2022 14:30:07 +0200
Subject: [PATCH] Update textcat scorer threshold behavior

For `textcat` (with exclusive classes) the scorer should always use a
threshold of 0.0 because there should be one predicted label per doc
and the numeric score for that particular label should not matter.
---
 spacy/pipeline/textcat.py            | 2 +-
 spacy/scorer.py                      | 4 +++-
 spacy/tests/pipeline/test_textcat.py | 4 ++--
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index c45f819fc..1a1e2cb64 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -72,7 +72,7 @@ subword_features = true
     "textcat",
     assigns=["doc.cats"],
     default_config={
-        "threshold": 0.5,
+        "threshold": 0.0,
         "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_scorer.v1"},
     },
diff --git a/spacy/scorer.py b/spacy/scorer.py
index 8cd755ac4..babc9d07f 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -446,7 +446,7 @@ class Scorer:
         labels (Iterable[str]): The set of possible labels. Defaults to [].
         multi_label (bool): Whether the attribute allows multiple labels.
             Defaults to True. When set to False (exclusive labels), missing
-            gold labels are interpreted as 0.0.
+            gold labels are interpreted as 0.0 and the threshold is set to 0.0.
         positive_label (str): The positive label for a binary task with
             exclusive classes. Defaults to None.
         threshold (float): Cutoff to consider a prediction "positive". Defaults
@@ -471,6 +471,8 @@ class Scorer:
         """
         if threshold is None:
             threshold = 0.5 if multi_label else 0.0
+        if not multi_label:
+            threshold = 0.0
         f_per_type = {label: PRFScore() for label in labels}
         auc_per_type = {label: ROCAUCScore() for label in labels}
         labels = set(labels)
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 0bb036a33..99a38755c 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -826,7 +826,7 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
 def test_textcat_threshold():
     # Ensure the scorer can be called with a different threshold
     nlp = English()
-    nlp.add_pipe("textcat")
+    nlp.add_pipe("textcat_multilabel")
 
     train_examples = []
     for text, annotations in TRAIN_DATA_SINGLE_LABEL:
@@ -849,7 +849,7 @@ def test_textcat_threshold():
     ...
     )
     pos_f = scores["cats_score"]
     assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
-    assert pos_f > macro_f
+    assert pos_f >= macro_f
 
 
 def test_textcat_multi_threshold():
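
Illustrative usage note (not part of the patch): with this change, an explicit threshold passed to `Scorer.score_cats` no longer affects scoring when `multi_label=False`; the highest-scoring label is always treated as the prediction. The sketch below is a minimal, hypothetical example built on the public `Scorer.score_cats`, `Example.from_dict`, and `Doc.cats` APIs; the text and the scores are invented for illustration.

from spacy.lang.en import English
from spacy.scorer import Scorer
from spacy.training import Example

nlp = English()

# Hypothetical prediction: the winning exclusive-class score is only 0.4.
pred = nlp.make_doc("This movie was surprisingly good.")
pred.cats = {"POSITIVE": 0.4, "NEGATIVE": 0.3}
example = Example.from_dict(pred, {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}})

# With multi_label=False the scorer now forces the threshold to 0.0, so the
# top-scoring label (POSITIVE at 0.4) is scored as the predicted class even
# though an explicit threshold of 0.5 is passed.
scores = Scorer.score_cats(
    [example],
    "cats",
    labels=["POSITIVE", "NEGATIVE"],
    multi_label=False,
    threshold=0.5,  # ignored for exclusive classes after this patch
)
print(scores["cats_f_per_type"]["POSITIVE"])  # should show recall of 1.0

Before this patch, the explicit threshold of 0.5 would have suppressed the 0.4 prediction and POSITIVE recall would have been 0.0.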