mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 23:24:54 +03:00
Update textcat scorer threshold behavior
For `textcat` (with exclusive classes) the scorer should always use a threshold of 0.0 because there should be one predicted label per doc and the numeric score for that particular label should not matter.
This commit is contained in:
parent
84d9cb6b38
commit
ebbbc3a611
|
@ -72,7 +72,7 @@ subword_features = true
|
|||
"textcat",
|
||||
assigns=["doc.cats"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"threshold": 0.0,
|
||||
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_scorer.v1"},
|
||||
},
|
||||
|
|
|
@ -446,7 +446,7 @@ class Scorer:
|
|||
labels (Iterable[str]): The set of possible labels. Defaults to [].
|
||||
multi_label (bool): Whether the attribute allows multiple labels.
|
||||
Defaults to True. When set to False (exclusive labels), missing
|
||||
gold labels are interpreted as 0.0.
|
||||
gold labels are interpreted as 0.0 and the threshold is set to 0.0.
|
||||
positive_label (str): The positive label for a binary task with
|
||||
exclusive classes. Defaults to None.
|
||||
threshold (float): Cutoff to consider a prediction "positive". Defaults
|
||||
|
@ -471,6 +471,8 @@ class Scorer:
|
|||
"""
|
||||
if threshold is None:
|
||||
threshold = 0.5 if multi_label else 0.0
|
||||
if not multi_label:
|
||||
threshold = 0.0
|
||||
f_per_type = {label: PRFScore() for label in labels}
|
||||
auc_per_type = {label: ROCAUCScore() for label in labels}
|
||||
labels = set(labels)
|
||||
|
|
|
@ -826,7 +826,7 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
|
|||
def test_textcat_threshold():
|
||||
# Ensure the scorer can be called with a different threshold
|
||||
nlp = English()
|
||||
nlp.add_pipe("textcat")
|
||||
nlp.add_pipe("textcat_multilabel")
|
||||
|
||||
train_examples = []
|
||||
for text, annotations in TRAIN_DATA_SINGLE_LABEL:
|
||||
|
@ -849,7 +849,7 @@ def test_textcat_threshold():
|
|||
)
|
||||
pos_f = scores["cats_score"]
|
||||
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
|
||||
assert pos_f > macro_f
|
||||
assert pos_f >= macro_f
|
||||
|
||||
|
||||
def test_textcat_multi_threshold():
|
||||
|
|
Loading…
Reference in New Issue
Block a user