Fix Scorer.score_cats for missing labels (#9443)

* Fix Scorer.score_cats for missing labels

* Add test case for Scorer.score_cats missing labels

* semantic nitpick

* black formatting

* adjust test to give different results depending on multi_label setting

* fix loss function according to whether or not missing values are supported

* add note to docs

* small fixes

* make mypy happy

* Update spacy/pipeline/textcat.py

Co-authored-by: Florian Cäsar <florian.caesar@pm.me>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: svlandeg <svlandeg@github.com>
Florian Cäsar 2021-12-29 11:04:39 +01:00 committed by GitHub
parent b8106e0f95
commit 86e71e7b19
7 changed files with 103 additions and 19 deletions


@@ -1,6 +1,6 @@
 # cython: infer_types=True, profile=True, binding=True
+from itertools import islice
 from typing import Optional, Callable
-from itertools import islice
 import srsly
 from thinc.api import Model, SequenceCategoricalCrossentropy, Config


@@ -1,9 +1,10 @@
-import numpy
 from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
 from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
 from thinc.api import Optimizer
 from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
+import numpy
+
 from ..compat import Protocol, runtime_checkable
 from ..scorer import Scorer
 from ..language import Language


@@ -1,8 +1,8 @@
+from itertools import islice
 from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
 import numpy
-from itertools import islice

 from .trainable_pipe import TrainablePipe
 from ..language import Language
@@ -158,6 +158,13 @@ class TextCategorizer(TrainablePipe):
         self.cfg = dict(cfg)
         self.scorer = scorer

+    @property
+    def support_missing_values(self):
+        # There are no missing values as the textcat should always
+        # predict exactly one label. All other labels are 0.0.
+        # Subclasses may override this property to change internal behaviour.
+        return False
+
     @property
     def labels(self) -> Tuple[str]:
         """RETURNS (Tuple[str]): The labels currently added to the component.
@@ -294,7 +301,7 @@ class TextCategorizer(TrainablePipe):
             for j, label in enumerate(self.labels):
                 if label in eg.reference.cats:
                     truths[i, j] = eg.reference.cats[label]
-                else:
+                elif self.support_missing_values:
                     not_missing[i, j] = 0.0
         truths = self.model.ops.asarray(truths)  # type: ignore
         return truths, not_missing  # type: ignore


@@ -1,8 +1,8 @@
+from itertools import islice
 from typing import Iterable, Optional, Dict, List, Callable, Any
+from thinc.api import Model, Config
 from thinc.types import Floats2d
-from thinc.api import Model, Config
-from itertools import islice

 from ..language import Language
 from ..training import Example, validate_get_examples
@@ -158,6 +158,10 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         self.cfg = dict(cfg)
         self.scorer = scorer

+    @property
+    def support_missing_values(self):
+        return True
+
     def initialize(  # type: ignore[override]
         self,
         get_examples: Callable[[], Iterable[Example]],
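
The two support_missing_values properties above only switch how the gold-standard truths and the not_missing mask are built in _examples_to_truth; the mask then cancels the contribution of missing labels to the squared-error gradient. Below is a rough standalone sketch of that interaction (editorial example, not part of the commit; it assumes a plain mean-squared-error loss like textcat's, and the label order, helper name and data are illustrative):

import numpy

labels = ["winter", "summer", "autumn", "spring"]
# gold annotations: the second doc has no value at all for "spring"
gold = [
    {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0},
    {"winter": 0.0, "summer": 0.0, "autumn": 1.0},
]
# model scores: the second doc predicts both "autumn" and "spring"
scores = numpy.asarray([[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f")


def mse_loss(gold, scores, support_missing_values):
    # mirrors _examples_to_truth: build the truths and the not_missing mask
    truths = numpy.zeros(scores.shape, dtype="f")
    not_missing = numpy.ones(scores.shape, dtype="f")
    for i, cats in enumerate(gold):
        for j, label in enumerate(labels):
            if label in cats:
                truths[i, j] = cats[label]
            elif support_missing_values:
                # missing gold label: mask the cell so it adds nothing to the loss
                not_missing[i, j] = 0.0
    d_scores = (scores - truths) * not_missing
    return float((d_scores ** 2).mean())


print(mse_loss(gold, scores, support_missing_values=True))   # 0.0   (textcat_multilabel)
print(mse_loss(gold, scores, support_missing_values=False))  # 0.125 (textcat)

With support_missing_values=False the unannotated "spring" in the second document counts as an explicit 0.0, so the stray prediction is penalised; with True it is masked out and the loss stays 0. These are the same values the new test_textcat_loss test below expects.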


@@ -445,7 +445,8 @@
         getter(doc, attr) should return the values for the individual doc.
         labels (Iterable[str]): The set of possible labels. Defaults to [].
         multi_label (bool): Whether the attribute allows multiple labels.
-            Defaults to True.
+            Defaults to True. When set to False (exclusive labels), missing
+            gold labels are interpreted as 0.0.
         positive_label (str): The positive label for a binary task with
             exclusive classes. Defaults to None.
         threshold (float): Cutoff to consider a prediction "positive". Defaults
@@ -484,13 +485,15 @@
             for label in labels:
                 pred_score = pred_cats.get(label, 0.0)
-                gold_score = gold_cats.get(label, 0.0)
+                gold_score = gold_cats.get(label)
+                if not gold_score and not multi_label:
+                    gold_score = 0.0
                 if gold_score is not None:
                     auc_per_type[label].score_set(pred_score, gold_score)
             if multi_label:
                 for label in labels:
                     pred_score = pred_cats.get(label, 0.0)
-                    gold_score = gold_cats.get(label, 0.0)
+                    gold_score = gold_cats.get(label)
                     if gold_score is not None:
                         if pred_score >= threshold and gold_score > 0:
                             f_per_type[label].tp += 1
@@ -502,7 +505,6 @@
                 # Get the highest-scoring for each.
                 pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
                 gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
-                if gold_score is not None:
                 if pred_label == gold_label and pred_score >= threshold:
                     f_per_type[pred_label].tp += 1
                 else:
@@ -511,7 +513,7 @@
                         f_per_type[pred_label].fp += 1
             elif gold_cats:
                 gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
-                if gold_score is not None and gold_score > 0:
+                if gold_score > 0:
                     f_per_type[gold_label].fn += 1
             elif pred_cats:
                 pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])


@@ -725,6 +725,72 @@ def test_textcat_evaluation():
     assert scores["cats_micro_r"] == 4 / 6


+@pytest.mark.parametrize(
+    "multi_label,spring_p",
+    [(True, 1 / 1), (False, 1 / 2)],
+)
+def test_textcat_eval_missing(multi_label: bool, spring_p: float):
+    """
+    multi-label: the missing 'spring' in gold_doc_2 doesn't incur a penalty
+    exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0"""
+    train_examples = []
+    nlp = English()
+
+    ref1 = nlp("one")
+    ref1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    pred1 = nlp("one")
+    pred1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    train_examples.append(Example(ref1, pred1))
+
+    ref2 = nlp("two")
+    # reference 'spring' is missing, pred 'spring' is 1
+    ref2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
+    pred2 = nlp("two")
+    pred2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    train_examples.append(Example(pred2, ref2))
+
+    scores = Scorer().score_cats(
+        train_examples,
+        "cats",
+        labels=["winter", "summer", "spring", "autumn"],
+        multi_label=multi_label,
+    )
+    assert scores["cats_f_per_type"]["spring"]["p"] == spring_p
+    assert scores["cats_f_per_type"]["spring"]["r"] == 1 / 1
+
+
+@pytest.mark.parametrize(
+    "multi_label,expected_loss",
+    [(True, 0), (False, 0.125)],
+)
+def test_textcat_loss(multi_label: bool, expected_loss: float):
+    """
+    multi-label: the missing 'spring' in gold_doc_2 doesn't incur an increase in loss
+    exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0 and adds to the loss"""
+    train_examples = []
+    nlp = English()
+
+    doc1 = nlp("one")
+    cats1 = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    train_examples.append(Example.from_dict(doc1, {"cats": cats1}))
+
+    doc2 = nlp("two")
+    cats2 = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
+    train_examples.append(Example.from_dict(doc2, {"cats": cats2}))
+
+    if multi_label:
+        textcat = nlp.add_pipe("textcat_multilabel")
+    else:
+        textcat = nlp.add_pipe("textcat")
+    textcat.initialize(lambda: train_examples)
+    assert isinstance(textcat, TextCategorizer)
+    scores = textcat.model.ops.asarray(
+        [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f"  # type: ignore
+    )
+    loss, d_scores = textcat.get_loss(train_examples, scores)
+    assert loss == expected_loss
+
+
 def test_textcat_threshold():
     # Ensure the scorer can be called with a different threshold
     nlp = English()


@@ -34,7 +34,11 @@ only.
 Predictions will be saved to `doc.cats` as a dictionary, where the key is the
 name of the category and the value is a score between 0 and 1 (inclusive). For
 `textcat` (exclusive categories), the scores will sum to 1, while for
-`textcat_multilabel` there is no particular guarantee about their sum.
+`textcat_multilabel` there is no particular guarantee about their sum. This also
+means that for `textcat`, missing values are equated to a value of 0 (i.e.
+`False`) and are counted as such towards the loss and scoring metrics. This is
+not the case for `textcat_multilabel`, where missing values in the gold standard
+data do not influence the loss or accuracy calculations.
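
As an illustration of the behaviour documented above (editorial example, not part of the committed docs; the category names and the `nlp` pipeline are hypothetical):

doc = nlp("I watched the game last night")
doc.cats = {"sports": 1.0}  # "politics" is not annotated at all
# textcat:            the absent "politics" counts as an explicit 0.0 (a negative
#                     example) and affects both the loss and the evaluation scores
# textcat_multilabel: the absent "politics" is treated as missing and is ignored
#                     by the loss and by Scorer.score_cats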

 Note that when assigning values to create training data, the score of each
 category must be 0 or 1. Using other values, for example to create a document