Update textcat scorer threshold behavior

For `textcat` (with exclusive classes) the scorer should always use a
threshold of 0.0 because there should be one predicted label per doc and
the numeric score for that particular label should not matter.
This commit is contained in:
Adriane Boyd 2022-10-24 14:30:07 +02:00
parent 84d9cb6b38
commit ebbbc3a611
3 changed files with 6 additions and 4 deletions

View File

@ -72,7 +72,7 @@ subword_features = true
"textcat", "textcat",
assigns=["doc.cats"], assigns=["doc.cats"],
default_config={ default_config={
"threshold": 0.5, "threshold": 0.0,
"model": DEFAULT_SINGLE_TEXTCAT_MODEL, "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
"scorer": {"@scorers": "spacy.textcat_scorer.v1"}, "scorer": {"@scorers": "spacy.textcat_scorer.v1"},
}, },

View File

@ -446,7 +446,7 @@ class Scorer:
labels (Iterable[str]): The set of possible labels. Defaults to []. labels (Iterable[str]): The set of possible labels. Defaults to [].
multi_label (bool): Whether the attribute allows multiple labels. multi_label (bool): Whether the attribute allows multiple labels.
Defaults to True. When set to False (exclusive labels), missing Defaults to True. When set to False (exclusive labels), missing
gold labels are interpreted as 0.0. gold labels are interpreted as 0.0 and the threshold is set to 0.0.
positive_label (str): The positive label for a binary task with positive_label (str): The positive label for a binary task with
exclusive classes. Defaults to None. exclusive classes. Defaults to None.
threshold (float): Cutoff to consider a prediction "positive". Defaults threshold (float): Cutoff to consider a prediction "positive". Defaults
@ -471,6 +471,8 @@ class Scorer:
""" """
if threshold is None: if threshold is None:
threshold = 0.5 if multi_label else 0.0 threshold = 0.5 if multi_label else 0.0
if not multi_label:
threshold = 0.0
f_per_type = {label: PRFScore() for label in labels} f_per_type = {label: PRFScore() for label in labels}
auc_per_type = {label: ROCAUCScore() for label in labels} auc_per_type = {label: ROCAUCScore() for label in labels}
labels = set(labels) labels = set(labels)

View File

@ -826,7 +826,7 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
def test_textcat_threshold(): def test_textcat_threshold():
# Ensure the scorer can be called with a different threshold # Ensure the scorer can be called with a different threshold
nlp = English() nlp = English()
nlp.add_pipe("textcat") nlp.add_pipe("textcat_multilabel")
train_examples = [] train_examples = []
for text, annotations in TRAIN_DATA_SINGLE_LABEL: for text, annotations in TRAIN_DATA_SINGLE_LABEL:
@ -849,7 +849,7 @@ def test_textcat_threshold():
) )
pos_f = scores["cats_score"] pos_f = scores["cats_score"]
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
assert pos_f > macro_f assert pos_f >= macro_f
def test_textcat_multi_threshold(): def test_textcat_multi_threshold():