mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Fix Scorer.score_cats for missing labels (#9443)
* Fix Scorer.score_cats for missing labels * Add test case for Scorer.score_cats missing labels * semantic nitpick * black formatting * adjust test to give different results depending on multi_label setting * fix loss function according to whether or not missing values are supported * add note to docs * small fixes * make mypy happy * Update spacy/pipeline/textcat.py Co-authored-by: Florian Cäsar <florian.caesar@pm.me> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
parent
b8106e0f95
commit
86e71e7b19
|
@ -1,6 +1,6 @@
|
||||||
# cython: infer_types=True, profile=True, binding=True
|
# cython: infer_types=True, profile=True, binding=True
|
||||||
from itertools import islice
|
|
||||||
from typing import Optional, Callable
|
from typing import Optional, Callable
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import Model, SequenceCategoricalCrossentropy, Config
|
from thinc.api import Model, SequenceCategoricalCrossentropy, Config
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import numpy
|
|
||||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
|
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
|
||||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||||
from thinc.api import Optimizer
|
from thinc.api import Optimizer
|
||||||
from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
|
from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
|
||||||
from ..compat import Protocol, runtime_checkable
|
from ..compat import Protocol, runtime_checkable
|
||||||
from ..scorer import Scorer
|
from ..scorer import Scorer
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from itertools import islice
|
|
||||||
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
|
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
|
||||||
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
import numpy
|
import numpy
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
@ -158,6 +158,13 @@ class TextCategorizer(TrainablePipe):
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def support_missing_values(self):
|
||||||
|
# There are no missing values as the textcat should always
|
||||||
|
# predict exactly one label. All other labels are 0.0
|
||||||
|
# Subclasses may override this property to change internal behaviour.
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self) -> Tuple[str]:
|
def labels(self) -> Tuple[str]:
|
||||||
"""RETURNS (Tuple[str]): The labels currently added to the component.
|
"""RETURNS (Tuple[str]): The labels currently added to the component.
|
||||||
|
@ -294,7 +301,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
for j, label in enumerate(self.labels):
|
for j, label in enumerate(self.labels):
|
||||||
if label in eg.reference.cats:
|
if label in eg.reference.cats:
|
||||||
truths[i, j] = eg.reference.cats[label]
|
truths[i, j] = eg.reference.cats[label]
|
||||||
else:
|
elif self.support_missing_values:
|
||||||
not_missing[i, j] = 0.0
|
not_missing[i, j] = 0.0
|
||||||
truths = self.model.ops.asarray(truths) # type: ignore
|
truths = self.model.ops.asarray(truths) # type: ignore
|
||||||
return truths, not_missing # type: ignore
|
return truths, not_missing # type: ignore
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from itertools import islice
|
|
||||||
from typing import Iterable, Optional, Dict, List, Callable, Any
|
from typing import Iterable, Optional, Dict, List, Callable, Any
|
||||||
|
|
||||||
from thinc.api import Model, Config
|
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
|
from thinc.api import Model, Config
|
||||||
|
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..training import Example, validate_get_examples
|
from ..training import Example, validate_get_examples
|
||||||
|
@ -158,6 +158,10 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def support_missing_values(self):
|
||||||
|
return True
|
||||||
|
|
||||||
def initialize( # type: ignore[override]
|
def initialize( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
get_examples: Callable[[], Iterable[Example]],
|
get_examples: Callable[[], Iterable[Example]],
|
||||||
|
|
|
@ -445,7 +445,8 @@ class Scorer:
|
||||||
getter(doc, attr) should return the values for the individual doc.
|
getter(doc, attr) should return the values for the individual doc.
|
||||||
labels (Iterable[str]): The set of possible labels. Defaults to [].
|
labels (Iterable[str]): The set of possible labels. Defaults to [].
|
||||||
multi_label (bool): Whether the attribute allows multiple labels.
|
multi_label (bool): Whether the attribute allows multiple labels.
|
||||||
Defaults to True.
|
Defaults to True. When set to False (exclusive labels), missing
|
||||||
|
gold labels are interpreted as 0.0.
|
||||||
positive_label (str): The positive label for a binary task with
|
positive_label (str): The positive label for a binary task with
|
||||||
exclusive classes. Defaults to None.
|
exclusive classes. Defaults to None.
|
||||||
threshold (float): Cutoff to consider a prediction "positive". Defaults
|
threshold (float): Cutoff to consider a prediction "positive". Defaults
|
||||||
|
@ -484,13 +485,15 @@ class Scorer:
|
||||||
|
|
||||||
for label in labels:
|
for label in labels:
|
||||||
pred_score = pred_cats.get(label, 0.0)
|
pred_score = pred_cats.get(label, 0.0)
|
||||||
gold_score = gold_cats.get(label, 0.0)
|
gold_score = gold_cats.get(label)
|
||||||
|
if not gold_score and not multi_label:
|
||||||
|
gold_score = 0.0
|
||||||
if gold_score is not None:
|
if gold_score is not None:
|
||||||
auc_per_type[label].score_set(pred_score, gold_score)
|
auc_per_type[label].score_set(pred_score, gold_score)
|
||||||
if multi_label:
|
if multi_label:
|
||||||
for label in labels:
|
for label in labels:
|
||||||
pred_score = pred_cats.get(label, 0.0)
|
pred_score = pred_cats.get(label, 0.0)
|
||||||
gold_score = gold_cats.get(label, 0.0)
|
gold_score = gold_cats.get(label)
|
||||||
if gold_score is not None:
|
if gold_score is not None:
|
||||||
if pred_score >= threshold and gold_score > 0:
|
if pred_score >= threshold and gold_score > 0:
|
||||||
f_per_type[label].tp += 1
|
f_per_type[label].tp += 1
|
||||||
|
@ -502,7 +505,6 @@ class Scorer:
|
||||||
# Get the highest-scoring for each.
|
# Get the highest-scoring for each.
|
||||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||||
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
|
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
|
||||||
if gold_score is not None:
|
|
||||||
if pred_label == gold_label and pred_score >= threshold:
|
if pred_label == gold_label and pred_score >= threshold:
|
||||||
f_per_type[pred_label].tp += 1
|
f_per_type[pred_label].tp += 1
|
||||||
else:
|
else:
|
||||||
|
@ -511,7 +513,7 @@ class Scorer:
|
||||||
f_per_type[pred_label].fp += 1
|
f_per_type[pred_label].fp += 1
|
||||||
elif gold_cats:
|
elif gold_cats:
|
||||||
gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
|
gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
|
||||||
if gold_score is not None and gold_score > 0:
|
if gold_score > 0:
|
||||||
f_per_type[gold_label].fn += 1
|
f_per_type[gold_label].fn += 1
|
||||||
elif pred_cats:
|
elif pred_cats:
|
||||||
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
|
||||||
|
|
|
@ -725,6 +725,72 @@ def test_textcat_evaluation():
|
||||||
assert scores["cats_micro_r"] == 4 / 6
|
assert scores["cats_micro_r"] == 4 / 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"multi_label,spring_p",
|
||||||
|
[(True, 1 / 1), (False, 1 / 2)],
|
||||||
|
)
|
||||||
|
def test_textcat_eval_missing(multi_label: bool, spring_p: float):
|
||||||
|
"""
|
||||||
|
multi-label: the missing 'spring' in gold_doc_2 doesn't incur a penalty
|
||||||
|
exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0"""
|
||||||
|
train_examples = []
|
||||||
|
nlp = English()
|
||||||
|
|
||||||
|
ref1 = nlp("one")
|
||||||
|
ref1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||||
|
pred1 = nlp("one")
|
||||||
|
pred1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||||
|
train_examples.append(Example(ref1, pred1))
|
||||||
|
|
||||||
|
ref2 = nlp("two")
|
||||||
|
# reference 'spring' is missing, pred 'spring' is 1
|
||||||
|
ref2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
|
||||||
|
pred2 = nlp("two")
|
||||||
|
pred2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||||
|
train_examples.append(Example(pred2, ref2))
|
||||||
|
|
||||||
|
scores = Scorer().score_cats(
|
||||||
|
train_examples,
|
||||||
|
"cats",
|
||||||
|
labels=["winter", "summer", "spring", "autumn"],
|
||||||
|
multi_label=multi_label,
|
||||||
|
)
|
||||||
|
assert scores["cats_f_per_type"]["spring"]["p"] == spring_p
|
||||||
|
assert scores["cats_f_per_type"]["spring"]["r"] == 1 / 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"multi_label,expected_loss",
|
||||||
|
[(True, 0), (False, 0.125)],
|
||||||
|
)
|
||||||
|
def test_textcat_loss(multi_label: bool, expected_loss: float):
|
||||||
|
"""
|
||||||
|
multi-label: the missing 'spring' in gold_doc_2 doesn't incur an increase in loss
|
||||||
|
exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0 and adds to the loss"""
|
||||||
|
train_examples = []
|
||||||
|
nlp = English()
|
||||||
|
|
||||||
|
doc1 = nlp("one")
|
||||||
|
cats1 = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
|
||||||
|
train_examples.append(Example.from_dict(doc1, {"cats": cats1}))
|
||||||
|
|
||||||
|
doc2 = nlp("two")
|
||||||
|
cats2 = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
|
||||||
|
train_examples.append(Example.from_dict(doc2, {"cats": cats2}))
|
||||||
|
|
||||||
|
if multi_label:
|
||||||
|
textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
else:
|
||||||
|
textcat = nlp.add_pipe("textcat")
|
||||||
|
textcat.initialize(lambda: train_examples)
|
||||||
|
assert isinstance(textcat, TextCategorizer)
|
||||||
|
scores = textcat.model.ops.asarray(
|
||||||
|
[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f" # type: ignore
|
||||||
|
)
|
||||||
|
loss, d_scores = textcat.get_loss(train_examples, scores)
|
||||||
|
assert loss == expected_loss
|
||||||
|
|
||||||
|
|
||||||
def test_textcat_threshold():
|
def test_textcat_threshold():
|
||||||
# Ensure the scorer can be called with a different threshold
|
# Ensure the scorer can be called with a different threshold
|
||||||
nlp = English()
|
nlp = English()
|
||||||
|
|
|
@ -34,7 +34,11 @@ only.
|
||||||
Predictions will be saved to `doc.cats` as a dictionary, where the key is the
|
Predictions will be saved to `doc.cats` as a dictionary, where the key is the
|
||||||
name of the category and the value is a score between 0 and 1 (inclusive). For
|
name of the category and the value is a score between 0 and 1 (inclusive). For
|
||||||
`textcat` (exclusive categories), the scores will sum to 1, while for
|
`textcat` (exclusive categories), the scores will sum to 1, while for
|
||||||
`textcat_multilabel` there is no particular guarantee about their sum.
|
`textcat_multilabel` there is no particular guarantee about their sum. This also
|
||||||
|
means that for `textcat`, missing values are equated to a value of 0 (i.e.
|
||||||
|
`False`) and are counted as such towards the loss and scoring metrics. This is
|
||||||
|
not the case for `textcat_multilabel`, where missing values in the gold standard
|
||||||
|
data do not influence the loss or accuracy calculations.
|
||||||
|
|
||||||
Note that when assigning values to create training data, the score of each
|
Note that when assigning values to create training data, the score of each
|
||||||
category must be 0 or 1. Using other values, for example to create a document
|
category must be 0 or 1. Using other values, for example to create a document
|
||||||
|
|
Loading…
Reference in New Issue
Block a user