Update textcat scorer threshold behavior
For `textcat` (with exclusive classes) the scorer should always use a threshold of 0.0 because there should be one predicted label per doc and the numeric score for that particular label should not matter.
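To make the rationale concrete: with exclusive classes the predicted label is simply the argmax over the scores, and once there are three or more classes that winning score can easily fall below 0.5, so any positive cutoff could leave a doc with no predicted label at all. A minimal sketch with made-up scores:

# With mutually exclusive classes, the predicted label is the argmax of the
# scores, and with three or more classes the winning score can fall below 0.5.
doc_cats = {"POSITIVE": 0.4, "NEGATIVE": 0.35, "NEUTRAL": 0.25}  # made-up softmax output

predicted = max(doc_cats, key=doc_cats.get)  # "POSITIVE"
assert doc_cats[predicted] < 0.5             # a 0.5 cutoff would discard it
# A threshold of 0.0 keeps the argmax label regardless of its numeric score.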
This commit is contained in:
parent 84d9cb6b38
commit ebbbc3a611
@@ -72,7 +72,7 @@ subword_features = true
     "textcat",
     assigns=["doc.cats"],
     default_config={
-        "threshold": 0.5,
+        "threshold": 0.0,
         "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_scorer.v1"},
     },
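After this change, a freshly added `textcat` component carries the new default in its config; a quick sanity check on a blank pipeline (a sketch, not part of the commit):

import spacy

nlp = spacy.blank("en")
nlp.add_pipe("textcat")
# The factory default is now 0.0 rather than 0.5:
print(nlp.get_pipe_config("textcat")["threshold"])  # 0.0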
@@ -446,7 +446,7 @@ class Scorer:
         labels (Iterable[str]): The set of possible labels. Defaults to [].
         multi_label (bool): Whether the attribute allows multiple labels.
             Defaults to True. When set to False (exclusive labels), missing
-            gold labels are interpreted as 0.0.
+            gold labels are interpreted as 0.0 and the threshold is set to 0.0.
         positive_label (str): The positive label for a binary task with
             exclusive classes. Defaults to None.
         threshold (float): Cutoff to consider a prediction "positive". Defaults
@@ -471,6 +471,8 @@ class Scorer:
         """
         if threshold is None:
             threshold = 0.5 if multi_label else 0.0
+        if not multi_label:
+            threshold = 0.0
         f_per_type = {label: PRFScore() for label in labels}
         auc_per_type = {label: ROCAUCScore() for label in labels}
         labels = set(labels)
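With the added guard, a caller-supplied threshold is overridden whenever `multi_label=False`. A self-contained sketch of the effect on `Scorer.score_cats`, using made-up docs, labels, and scores:

from spacy.lang.en import English
from spacy.scorer import Scorer
from spacy.tokens import Doc
from spacy.training import Example

nlp = English()
words = ["Great", "stuff"]
pred = Doc(nlp.vocab, words=words)
pred.cats = {"POS": 0.4, "NEG": 0.35, "NEU": 0.25}  # argmax is POS, but below 0.5
gold = Doc(nlp.vocab, words=words)
gold.cats = {"POS": 1.0, "NEG": 0.0, "NEU": 0.0}
example = Example(pred, gold)

scores = Scorer.score_cats(
    [example],
    "cats",
    labels=["POS", "NEG", "NEU"],
    multi_label=False,
    threshold=0.75,  # ignored: forced to 0.0 for exclusive classes
)
print(scores["cats_f_per_type"]["POS"])  # a true positive despite 0.4 < 0.75

Before this commit, the same call would have treated the correct argmax prediction as a miss, because 0.4 falls under the 0.75 cutoff.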
@@ -826,7 +826,7 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
 def test_textcat_threshold():
     # Ensure the scorer can be called with a different threshold
     nlp = English()
-    nlp.add_pipe("textcat")
+    nlp.add_pipe("textcat_multilabel")
 
     train_examples = []
     for text, annotations in TRAIN_DATA_SINGLE_LABEL:
@@ -849,7 +849,7 @@ def test_textcat_threshold():
     )
     pos_f = scores["cats_score"]
     assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
-    assert pos_f > macro_f
+    assert pos_f >= macro_f
 
 
 def test_textcat_multi_threshold():
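The test switches to `textcat_multilabel` because that is now the only component whose scorer honors a custom threshold, and the relaxed assertion (`>=` instead of `>`) presumably allows the positive-label score to tie the macro score. For contrast with the sketch above, the same made-up example scored on the multilabel path (again an illustration, not code from the commit):

from spacy.lang.en import English
from spacy.scorer import Scorer
from spacy.tokens import Doc
from spacy.training import Example

nlp = English()
words = ["Great", "stuff"]
pred = Doc(nlp.vocab, words=words)
pred.cats = {"POS": 0.4, "NEG": 0.35, "NEU": 0.25}
gold = Doc(nlp.vocab, words=words)
gold.cats = {"POS": 1.0, "NEG": 0.0, "NEU": 0.0}
example = Example(pred, gold)

scores = Scorer.score_cats(
    [example],
    "cats",
    labels=["POS", "NEG", "NEU"],
    multi_label=True,
    threshold=0.75,  # honored here: POS (0.4) does not clear the cutoff
)
print(scores["cats_f_per_type"]["POS"]["r"])  # 0.0: the gold POS label was missed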