mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix Scorer.score_cats for missing labels (#9443)
* Fix Scorer.score_cats for missing labels * Add test case for Scorer.score_cats missing labels * semantic nitpick * black formatting * adjust test to give different results depending on multi_label setting * fix loss function according to whether or not missing values are supported * add note to docs * small fixes * make mypy happy * Update spacy/pipeline/textcat.py Co-authored-by: Florian Cäsar <florian.caesar@pm.me> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
parent
b8106e0f95
commit
86e71e7b19
|
@ -1,6 +1,6 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from itertools import islice
|
||||
from typing import Optional, Callable
|
||||
from itertools import islice
|
||||
|
||||
import srsly
|
||||
from thinc.api import Model, SequenceCategoricalCrossentropy, Config
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import numpy
|
||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
|
||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||
from thinc.api import Optimizer
|
||||
from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
|
||||
|
||||
import numpy
|
||||
|
||||
from ..compat import Protocol, runtime_checkable
|
||||
from ..scorer import Scorer
|
||||
from ..language import Language
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from itertools import islice
|
||||
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
|
||||
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
||||
from thinc.types import Floats2d
|
||||
import numpy
|
||||
from itertools import islice
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ..language import Language
|
||||
|
@ -158,6 +158,13 @@ class TextCategorizer(TrainablePipe):
|
|||
self.cfg = dict(cfg)
|
||||
self.scorer = scorer
|
||||
|
||||
@property
|
||||
def support_missing_values(self):
|
||||
# There are no missing values as the textcat should always
|
||||
# predict exactly one label. All other labels are 0.0
|
||||
# Subclasses may override this property to change internal behaviour.
|
||||
return False
|
||||
|
||||
@property
|
||||
def labels(self) -> Tuple[str]:
|
||||
"""RETURNS (Tuple[str]): The labels currently added to the component.
|
||||
|
@ -294,7 +301,7 @@ class TextCategorizer(TrainablePipe):
|
|||
for j, label in enumerate(self.labels):
|
||||
if label in eg.reference.cats:
|
||||
truths[i, j] = eg.reference.cats[label]
|
||||
else:
|
||||
elif self.support_missing_values:
|
||||
not_missing[i, j] = 0.0
|
||||
truths = self.model.ops.asarray(truths) # type: ignore
|
||||
return truths, not_missing # type: ignore
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from itertools import islice
|
||||
from typing import Iterable, Optional, Dict, List, Callable, Any
|
||||
|
||||
from thinc.api import Model, Config
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import Model, Config
|
||||
|
||||
from itertools import islice
|
||||
|
||||
from ..language import Language
|
||||
from ..training import Example, validate_get_examples
|
||||
|
@ -158,6 +158,10 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
self.cfg = dict(cfg)
|
||||
self.scorer = scorer
|
||||
|
||||
@property
|
||||
def support_missing_values(self):
|
||||
return True
|
||||
|
||||
def initialize( # type: ignore[override]
|
||||
self,
|
||||
get_examples: Callable[[], Iterable[Example]],
|
||||
|
|
|
@ -445,7 +445,8 @@ class Scorer:
|
|||
getter(doc, attr) should return the values for the individual doc.
|
||||
labels (Iterable[str]): The set of possible labels. Defaults to [].
|
||||
multi_label (bool): Whether the attribute allows multiple labels.
|
||||
Defaults to True.
|
||||
Defaults to True. When set to False (exclusive labels), missing
|
||||
gold labels are interpreted as 0.0.
|
||||
positive_label (str): The positive label for a binary task with
|
||||
exclusive classes. Defaults to None.
|
||||
threshold (float): Cutoff to consider a prediction "positive". Defaults
|
||||
|
@ -484,13 +485,15 @@ class Scorer:
|
|||
|
||||
for label in labels:
|
||||
pred_score = pred_cats.get(label, 0.0)
|
||||
gold_score = gold_cats.get(label, 0.0)
|
||||
gold_score = gold_cats.get(label)
|
||||
if not gold_score and not multi_label:
|
||||
gold_score = 0.0
|
||||
if gold_score is not None:
|
||||
auc_per_type[label].score_set(pred_score, gold_score)
|
||||
if multi_label:
|
||||
for label in labels:
|
||||
pred_score = pred_cats.get(label, 0.0)
|
||||
gold_score = gold_cats.get(label, 0.0)
|
||||
gold_score = gold_cats.get(label)
|
||||
if gold_score is not None:
|
||||
if pred_score >= threshold and gold_score > 0:
|
||||
f_per_type[label].tp += 1
|
||||
|
@ -502,16 +505,15 @@ class Scorer:
|
|||
# Get the highest-scoring for each.
|
||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
|
||||
if gold_score is not None:
|
||||
if pred_label == gold_label and pred_score >= threshold:
|
||||
f_per_type[pred_label].tp += 1
|
||||
else:
|
||||
f_per_type[gold_label].fn += 1
|
||||
if pred_score >= threshold:
|
||||
f_per_type[pred_label].fp += 1
|
||||
if pred_label == gold_label and pred_score >= threshold:
|
||||
f_per_type[pred_label].tp += 1
|
||||
else:
|
||||
f_per_type[gold_label].fn += 1
|
||||
if pred_score >= threshold:
|
||||
f_per_type[pred_label].fp += 1
|
||||
elif gold_cats:
|
||||
gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
|
||||
if gold_score is not None and gold_score > 0:
|
||||
if gold_score > 0:
|
||||
f_per_type[gold_label].fn += 1
|
||||
elif pred_cats:
|
||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||
|
|
|
@ -725,6 +725,72 @@ def test_textcat_evaluation():
|
|||
assert scores["cats_micro_r"] == 4 / 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multi_label,spring_p",
|
||||
[(True, 1 / 1), (False, 1 / 2)],
|
||||
)
|
||||
def test_textcat_eval_missing(multi_label: bool, spring_p: float):
|
||||
"""
|
||||
multi-label: the missing 'spring' in gold_doc_2 doesn't incur a penalty
|
||||
exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0"""
|
||||
train_examples = []
|
||||
nlp = English()
|
||||
|
||||
ref1 = nlp("one")
|
||||
ref1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||
pred1 = nlp("one")
|
||||
pred1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||
train_examples.append(Example(ref1, pred1))
|
||||
|
||||
ref2 = nlp("two")
|
||||
# reference 'spring' is missing, pred 'spring' is 1
|
||||
ref2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
|
||||
pred2 = nlp("two")
|
||||
pred2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||
train_examples.append(Example(pred2, ref2))
|
||||
|
||||
scores = Scorer().score_cats(
|
||||
train_examples,
|
||||
"cats",
|
||||
labels=["winter", "summer", "spring", "autumn"],
|
||||
multi_label=multi_label,
|
||||
)
|
||||
assert scores["cats_f_per_type"]["spring"]["p"] == spring_p
|
||||
assert scores["cats_f_per_type"]["spring"]["r"] == 1 / 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multi_label,expected_loss",
|
||||
[(True, 0), (False, 0.125)],
|
||||
)
|
||||
def test_textcat_loss(multi_label: bool, expected_loss: float):
|
||||
"""
|
||||
multi-label: the missing 'spring' in gold_doc_2 doesn't incur an increase in loss
|
||||
exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0 and adds to the loss"""
|
||||
train_examples = []
|
||||
nlp = English()
|
||||
|
||||
doc1 = nlp("one")
|
||||
cats1 = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||
train_examples.append(Example.from_dict(doc1, {"cats": cats1}))
|
||||
|
||||
doc2 = nlp("two")
|
||||
cats2 = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
|
||||
train_examples.append(Example.from_dict(doc2, {"cats": cats2}))
|
||||
|
||||
if multi_label:
|
||||
textcat = nlp.add_pipe("textcat_multilabel")
|
||||
else:
|
||||
textcat = nlp.add_pipe("textcat")
|
||||
textcat.initialize(lambda: train_examples)
|
||||
assert isinstance(textcat, TextCategorizer)
|
||||
scores = textcat.model.ops.asarray(
|
||||
[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore
|
||||
)
|
||||
loss, d_scores = textcat.get_loss(train_examples, scores)
|
||||
assert loss == expected_loss
|
||||
|
||||
|
||||
def test_textcat_threshold():
|
||||
# Ensure the scorer can be called with a different threshold
|
||||
nlp = English()
|
||||
|
|
|
@ -34,7 +34,11 @@ only.
|
|||
Predictions will be saved to `doc.cats` as a dictionary, where the key is the
|
||||
name of the category and the value is a score between 0 and 1 (inclusive). For
|
||||
`textcat` (exclusive categories), the scores will sum to 1, while for
|
||||
`textcat_multilabel` there is no particular guarantee about their sum.
|
||||
`textcat_multilabel` there is no particular guarantee about their sum. This also
|
||||
means that for `textcat`, missing values are equated to a value of 0 (i.e.
|
||||
`False`) and are counted as such towards the loss and scoring metrics. This is
|
||||
not the case for `textcat_multilabel`, where missing values in the gold standard
|
||||
data do not influence the loss or accuracy calculations.
|
||||
|
||||
Note that when assigning values to create training data, the score of each
|
||||
category must be 0 or 1. Using other values, for example to create a document
|
||||
|
|
Loading…
Reference in New Issue
Block a user