mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
Update textcat scorer threshold behavior (#11696)
* Update textcat scorer threshold behavior For `textcat` (with exclusive classes) the scorer should always use a threshold of 0.0 because there should be one predicted label per doc and the numeric score for that particular label should not matter. * Rename to test_textcat_multilabel_threshold * Remove all uses of threshold for multi_label=False * Update Scorer.score_cats API docs * Add tests for score_cats with thresholds * Update textcat API docs * Fix types * Convert threshold back to float * Fix threshold type in docstring * Improve formatting in Scorer API docs
This commit is contained in:
parent
f7edd84b44
commit
420b1d854b
|
@ -72,7 +72,7 @@ subword_features = true
|
|||
"textcat",
|
||||
assigns=["doc.cats"],
|
||||
default_config={
|
||||
"threshold": 0.5,
|
||||
"threshold": 0.0,
|
||||
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_scorer.v1"},
|
||||
},
|
||||
|
@ -144,7 +144,8 @@ class TextCategorizer(TrainablePipe):
|
|||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
threshold (float): Unused, not needed for single-label (exclusive
|
||||
classes) classification.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_cats for the attribute "cats".
|
||||
|
||||
|
@ -154,7 +155,7 @@ class TextCategorizer(TrainablePipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self._rehearsal_model = None
|
||||
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
||||
cfg: Dict[str, Any] = {"labels": [], "threshold": threshold, "positive_label": None}
|
||||
self.cfg = dict(cfg)
|
||||
self.scorer = scorer
|
||||
|
||||
|
|
|
@ -446,7 +446,7 @@ class Scorer:
|
|||
labels (Iterable[str]): The set of possible labels. Defaults to [].
|
||||
multi_label (bool): Whether the attribute allows multiple labels.
|
||||
Defaults to True. When set to False (exclusive labels), missing
|
||||
gold labels are interpreted as 0.0.
|
||||
gold labels are interpreted as 0.0 and the threshold is set to 0.0.
|
||||
positive_label (str): The positive label for a binary task with
|
||||
exclusive classes. Defaults to None.
|
||||
threshold (float): Cutoff to consider a prediction "positive". Defaults
|
||||
|
@ -471,6 +471,8 @@ class Scorer:
|
|||
"""
|
||||
if threshold is None:
|
||||
threshold = 0.5 if multi_label else 0.0
|
||||
if not multi_label:
|
||||
threshold = 0.0
|
||||
f_per_type = {label: PRFScore() for label in labels}
|
||||
auc_per_type = {label: ROCAUCScore() for label in labels}
|
||||
labels = set(labels)
|
||||
|
@ -505,20 +507,18 @@ class Scorer:
|
|||
# Get the highest-scoring for each.
|
||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
|
||||
if pred_label == gold_label and pred_score >= threshold:
|
||||
if pred_label == gold_label:
|
||||
f_per_type[pred_label].tp += 1
|
||||
else:
|
||||
f_per_type[gold_label].fn += 1
|
||||
if pred_score >= threshold:
|
||||
f_per_type[pred_label].fp += 1
|
||||
f_per_type[pred_label].fp += 1
|
||||
elif gold_cats:
|
||||
gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
|
||||
if gold_score > 0:
|
||||
f_per_type[gold_label].fn += 1
|
||||
elif pred_cats:
|
||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||
if pred_score >= threshold:
|
||||
f_per_type[pred_label].fp += 1
|
||||
f_per_type[pred_label].fp += 1
|
||||
micro_prf = PRFScore()
|
||||
for label_prf in f_per_type.values():
|
||||
micro_prf.tp += label_prf.tp
|
||||
|
|
|
@ -823,10 +823,10 @@ def test_textcat_loss(multi_label: bool, expected_loss: float):
|
|||
assert loss == expected_loss
|
||||
|
||||
|
||||
def test_textcat_threshold():
|
||||
def test_textcat_multilabel_threshold():
|
||||
# Ensure the scorer can be called with a different threshold
|
||||
nlp = English()
|
||||
nlp.add_pipe("textcat")
|
||||
nlp.add_pipe("textcat_multilabel")
|
||||
|
||||
train_examples = []
|
||||
for text, annotations in TRAIN_DATA_SINGLE_LABEL:
|
||||
|
@ -849,7 +849,7 @@ def test_textcat_threshold():
|
|||
)
|
||||
pos_f = scores["cats_score"]
|
||||
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
|
||||
assert pos_f > macro_f
|
||||
assert pos_f >= macro_f
|
||||
|
||||
|
||||
def test_textcat_multi_threshold():
|
||||
|
|
|
@ -474,3 +474,50 @@ def test_prf_score():
|
|||
assert (a.precision, a.recall, a.fscore) == approx(
|
||||
(c.precision, c.recall, c.fscore)
|
||||
)
|
||||
|
||||
|
||||
def test_score_cats(en_tokenizer):
|
||||
text = "some text"
|
||||
gold_doc = en_tokenizer(text)
|
||||
gold_doc.cats = {"POSITIVE": 1.0, "NEGATIVE": 0.0}
|
||||
pred_doc = en_tokenizer(text)
|
||||
pred_doc.cats = {"POSITIVE": 0.75, "NEGATIVE": 0.25}
|
||||
example = Example(pred_doc, gold_doc)
|
||||
# threshold is ignored for multi_label=False
|
||||
scores1 = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=False,
|
||||
positive_label="POSITIVE",
|
||||
threshold=0.1,
|
||||
)
|
||||
scores2 = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=False,
|
||||
positive_label="POSITIVE",
|
||||
threshold=0.9,
|
||||
)
|
||||
assert scores1["cats_score"] == 1.0
|
||||
assert scores2["cats_score"] == 1.0
|
||||
assert scores1 == scores2
|
||||
# threshold is relevant for multi_label=True
|
||||
scores = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=True,
|
||||
threshold=0.9,
|
||||
)
|
||||
assert scores["cats_macro_f"] == 0.0
|
||||
# threshold is relevant for multi_label=True
|
||||
scores = Scorer.score_cats(
|
||||
[example],
|
||||
"cats",
|
||||
labels=list(gold_doc.cats.keys()),
|
||||
multi_label=True,
|
||||
threshold=0.1,
|
||||
)
|
||||
assert scores["cats_macro_f"] == 0.5
|
||||
|
|
|
@ -229,16 +229,17 @@ The reported `{attr}_score` depends on the classification properties:
|
|||
> print(scores["cats_macro_auc"])
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ---------------- | -------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. ~~Callable[[Doc, str], Dict[str, float]]~~ |
|
||||
| labels | The set of possible labels. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `multi_label` | Whether the attribute allows multiple labels. Defaults to `True`. ~~bool~~ |
|
||||
| `positive_label` | The positive label for a binary task with exclusive classes. Defaults to `None`. ~~Optional[str]~~ |
|
||||
| **RETURNS** | A dictionary containing the scores, with inapplicable scores as `None`. ~~Dict[str, Optional[float]]~~ |
|
||||
| Name | Description |
|
||||
| ---------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | The `Example` objects holding both the predictions and the correct gold-standard annotations. ~~Iterable[Example]~~ |
|
||||
| `attr` | The attribute to score. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `getter` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. ~~Callable[[Doc, str], Dict[str, float]]~~ |
|
||||
| labels | The set of possible labels. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `multi_label` | Whether the attribute allows multiple labels. Defaults to `True`. When set to `False` (exclusive labels), missing gold labels are interpreted as `0.0` and the threshold is set to `0.0`. ~~bool~~ |
|
||||
| `positive_label` | The positive label for a binary task with exclusive classes. Defaults to `None`. ~~Optional[str]~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive". Defaults to `0.5` for multi-label, and `0.0` (i.e. whatever's highest scoring) otherwise. ~~float~~ |
|
||||
| **RETURNS** | A dictionary containing the scores, with inapplicable scores as `None`. ~~Dict[str, Optional[float]]~~ |
|
||||
|
||||
## Scorer.score_links {#score_links tag="staticmethod" new="3"}
|
||||
|
||||
|
|
|
@ -63,7 +63,6 @@ architectures and their arguments and hyperparameters.
|
|||
> ```python
|
||||
> from spacy.pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL
|
||||
> config = {
|
||||
> "threshold": 0.5,
|
||||
> "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||
> }
|
||||
> nlp.add_pipe("textcat", config=config)
|
||||
|
@ -82,7 +81,7 @@ architectures and their arguments and hyperparameters.
|
|||
|
||||
| Setting | Description |
|
||||
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant for `textcat_multilabel` when calculating accuracy scores. ~~float~~ |
|
||||
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
| `scorer` | The scoring method. Defaults to [`Scorer.score_cats`](/api/scorer#score_cats) for the attribute `"cats"`. ~~Optional[Callable]~~ |
|
||||
|
||||
|
@ -123,7 +122,7 @@ shortcut for this and instantiate the component using its string name and
|
|||
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant for `textcat_multilabel` when calculating accuracy scores. ~~float~~ |
|
||||
| `scorer` | The scoring method. Defaults to [`Scorer.score_cats`](/api/scorer#score_cats) for the attribute `"cats"`. ~~Optional[Callable]~~ |
|
||||
|
||||
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||
|
|
Loading…
Reference in New Issue
Block a user