diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 54ce021af..2e0f364f0 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -1,6 +1,6 @@ # cython: infer_types=True, profile=True, binding=True -from itertools import islice from typing import Optional, Callable +from itertools import islice import srsly from thinc.api import Model, SequenceCategoricalCrossentropy, Config diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 829def1eb..01c9c407f 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -1,9 +1,10 @@ -import numpy from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Optimizer from thinc.types import Ragged, Ints2d, Floats2d, Ints1d +import numpy + from ..compat import Protocol, runtime_checkable from ..scorer import Scorer from ..language import Language diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 30a65ec52..e20ae87f1 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -1,8 +1,8 @@ -from itertools import islice from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config from thinc.types import Floats2d import numpy +from itertools import islice from .trainable_pipe import TrainablePipe from ..language import Language @@ -158,6 +158,13 @@ class TextCategorizer(TrainablePipe): self.cfg = dict(cfg) self.scorer = scorer + @property + def support_missing_values(self): + # There are no missing values as the textcat should always + # predict exactly one label. All other labels are 0.0 + # Subclasses may override this property to change internal behaviour. + return False + @property def labels(self) -> Tuple[str]: """RETURNS (Tuple[str]): The labels currently added to the component. @@ -294,7 +301,7 @@ class TextCategorizer(TrainablePipe): for j, label in enumerate(self.labels): if label in eg.reference.cats: truths[i, j] = eg.reference.cats[label] - else: + elif self.support_missing_values: not_missing[i, j] = 0.0 truths = self.model.ops.asarray(truths) # type: ignore return truths, not_missing # type: ignore diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index a7bfacca7..e33a885f8 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -1,8 +1,8 @@ -from itertools import islice from typing import Iterable, Optional, Dict, List, Callable, Any - -from thinc.api import Model, Config from thinc.types import Floats2d +from thinc.api import Model, Config + +from itertools import islice from ..language import Language from ..training import Example, validate_get_examples @@ -158,6 +158,10 @@ class MultiLabel_TextCategorizer(TextCategorizer): self.cfg = dict(cfg) self.scorer = scorer + @property + def support_missing_values(self): + return True + def initialize( # type: ignore[override] self, get_examples: Callable[[], Iterable[Example]], diff --git a/spacy/scorer.py b/spacy/scorer.py index 4d596b5e1..ae9338bd5 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -445,7 +445,8 @@ class Scorer: getter(doc, attr) should return the values for the individual doc. labels (Iterable[str]): The set of possible labels. Defaults to []. multi_label (bool): Whether the attribute allows multiple labels. - Defaults to True. + Defaults to True. 
+            Defaults to True. When set to False (exclusive labels), missing
+            gold labels are interpreted as 0.0.
         positive_label (str): The positive label for a binary task with
             exclusive classes. Defaults to None.
         threshold (float): Cutoff to consider a prediction "positive". Defaults
@@ -484,13 +485,15 @@ class Scorer:
 
             for label in labels:
                 pred_score = pred_cats.get(label, 0.0)
-                gold_score = gold_cats.get(label, 0.0)
+                gold_score = gold_cats.get(label)
+                if not gold_score and not multi_label:
+                    gold_score = 0.0
                 if gold_score is not None:
                     auc_per_type[label].score_set(pred_score, gold_score)
             if multi_label:
                 for label in labels:
                     pred_score = pred_cats.get(label, 0.0)
-                    gold_score = gold_cats.get(label, 0.0)
+                    gold_score = gold_cats.get(label)
                     if gold_score is not None:
                         if pred_score >= threshold and gold_score > 0:
                             f_per_type[label].tp += 1
@@ -502,16 +505,15 @@ class Scorer:
                 # Get the highest-scoring for each.
                 pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
                 gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
-                if gold_score is not None:
-                    if pred_label == gold_label and pred_score >= threshold:
-                        f_per_type[pred_label].tp += 1
-                    else:
-                        f_per_type[gold_label].fn += 1
-                        if pred_score >= threshold:
-                            f_per_type[pred_label].fp += 1
+                if pred_label == gold_label and pred_score >= threshold:
+                    f_per_type[pred_label].tp += 1
+                else:
+                    f_per_type[gold_label].fn += 1
+                    if pred_score >= threshold:
+                        f_per_type[pred_label].fp += 1
             elif gold_cats:
                 gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
-                if gold_score is not None and gold_score > 0:
+                if gold_score > 0:
                     f_per_type[gold_label].fn += 1
             elif pred_cats:
                 pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 282789f2b..52bf6ec5c 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -725,6 +725,72 @@ def test_textcat_evaluation():
     assert scores["cats_micro_r"] == 4 / 6
 
 
+@pytest.mark.parametrize(
+    "multi_label,spring_p",
+    [(True, 1 / 1), (False, 1 / 2)],
+)
+def test_textcat_eval_missing(multi_label: bool, spring_p: float):
+    """
+    multi-label: the missing 'spring' in gold_doc_2 doesn't incur a penalty
+    exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0"""
+    train_examples = []
+    nlp = English()
+
+    ref1 = nlp("one")
+    ref1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    pred1 = nlp("one")
+    pred1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    train_examples.append(Example(pred1, ref1))
+
+    ref2 = nlp("two")
+    # reference 'spring' is missing, pred 'spring' is 1
+    ref2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
+    pred2 = nlp("two")
+    pred2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    train_examples.append(Example(pred2, ref2))
+
+    scores = Scorer().score_cats(
+        train_examples,
+        "cats",
+        labels=["winter", "summer", "spring", "autumn"],
+        multi_label=multi_label,
+    )
+    assert scores["cats_f_per_type"]["spring"]["p"] == spring_p
+    assert scores["cats_f_per_type"]["spring"]["r"] == 1 / 1
+
+
+@pytest.mark.parametrize(
+    "multi_label,expected_loss",
+    [(True, 0), (False, 0.125)],
+)
+def test_textcat_loss(multi_label: bool, expected_loss: float):
+    """
+    multi-label: the missing 'spring' in gold_doc_2 doesn't incur an increase in loss
+    exclusive labels: the missing 'spring' in gold_doc_2 is interpreted as 0.0 and adds to the loss"""
+    train_examples = []
+    nlp = English()
+
+    doc1 = nlp("one")
+    cats1 = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
+    train_examples.append(Example.from_dict(doc1, {"cats": cats1}))
+
+    doc2 = nlp("two")
+    cats2 = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}
+    train_examples.append(Example.from_dict(doc2, {"cats": cats2}))
+
+    if multi_label:
+        textcat = nlp.add_pipe("textcat_multilabel")
+    else:
+        textcat = nlp.add_pipe("textcat")
+    textcat.initialize(lambda: train_examples)
+    assert isinstance(textcat, TextCategorizer)
+    scores = textcat.model.ops.asarray(
+        [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]], dtype="f"  # type: ignore
+    )
+    loss, d_scores = textcat.get_loss(train_examples, scores)
+    assert loss == expected_loss
+
+
 def test_textcat_threshold():
     # Ensure the scorer can be called with a different threshold
     nlp = English()
diff --git a/website/docs/api/textcategorizer.md b/website/docs/api/textcategorizer.md
index 47f868637..2ff569bad 100644
--- a/website/docs/api/textcategorizer.md
+++ b/website/docs/api/textcategorizer.md
@@ -34,7 +34,11 @@ only.
 Predictions will be saved to `doc.cats` as a dictionary, where the key is the
 name of the category and the value is a score between 0 and 1 (inclusive). For
 `textcat` (exclusive categories), the scores will sum to 1, while for
-`textcat_multilabel` there is no particular guarantee about their sum.
+`textcat_multilabel` there is no particular guarantee about their sum. This also
+means that for `textcat`, missing values are equated to a value of 0 (i.e.
+`False`) and are counted as such towards the loss and scoring metrics. This is
+not the case for `textcat_multilabel`, where missing values in the gold standard
+data do not influence the loss or accuracy calculations.
 
 Note that when assigning values to create training data, the score of each
 category must be 0 or 1. Using other values, for example to create a document
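
For reference, here is a small usage sketch of the behaviour tested and documented above. This is illustrative only and not part of the diff; it assumes spaCy v3.x with this change applied, and the variable names (`nlp`, `pred1`, `ref2`, etc.) are made up for the example.

```python
# Illustrative sketch: how a gold label missing from Doc.cats is scored.
import spacy
from spacy.scorer import Scorer
from spacy.training import Example

nlp = spacy.blank("en")
labels = ["winter", "summer", "spring", "autumn"]

# Example 1: "spring" is annotated in the reference and predicted correctly.
pred1 = nlp("one")
pred1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
ref1 = nlp("one")
ref1.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}

# Example 2: "spring" is missing from the reference but predicted as 1.0.
pred2 = nlp("two")
pred2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 0.0, "spring": 1.0}
ref2 = nlp("two")
ref2.cats = {"winter": 0.0, "summer": 0.0, "autumn": 1.0}

examples = [Example(pred1, ref1), Example(pred2, ref2)]

# Multi-label: the missing "spring" is skipped, so precision for "spring" is 1/1.
multi = Scorer().score_cats(examples, "cats", labels=labels, multi_label=True)
# Exclusive labels: the missing "spring" is read as 0.0, so the second
# prediction becomes a false positive and precision for "spring" drops to 1/2.
excl = Scorer().score_cats(examples, "cats", labels=labels, multi_label=False)

for scores in (multi, excl):
    print(scores["cats_f_per_type"]["spring"])
```

The split in behaviour follows from the two components' semantics: `textcat` must predict exactly one category per document, so an unannotated label can safely be read as 0.0, whereas for `textcat_multilabel` an unannotated label is genuinely unknown and is therefore excluded from both the loss and the evaluation.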