From 75bb7ad541a94c74127b57ffd6d674841767478c Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Thu, 17 Nov 2022 18:25:01 +0900 Subject: [PATCH] Check textcat values for validity (#11763) * Check textcat values for validity * Fix error numbers * Clean up vals reference * Check category value validity through training The _validate_categories is called in update, which for multilabel is inherited from the single label component. * Formatting --- spacy/errors.py | 2 ++ spacy/pipeline/textcat.py | 10 +++++++--- spacy/pipeline/textcat_multilabel.py | 8 +++++++- spacy/tests/pipeline/test_textcat.py | 24 ++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 278e5496a..1d29f0e17 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes): "during training, make sure to include it in 'annotating components'") # New errors added in v3.x + E851 = ("The 'textcat' component labels should only have values of 0 or 1, " + "but found value of '{val}'.") E852 = ("The tar file pulled from the remote attempted an unsafe path " "traversal.") E853 = ("Unsupported component factory name '{name}'. The character '.' is " diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 4023c4456..a86eb99d2 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -293,7 +293,7 @@ class TextCategorizer(TrainablePipe): bp_scores(gradient) if sgd is not None: self.finish_update(sgd) - losses[self.name] += (gradient**2).sum() + losses[self.name] += (gradient ** 2).sum() return losses def _examples_to_truth( @@ -327,7 +327,7 @@ class TextCategorizer(TrainablePipe): not_missing = self.model.ops.asarray(not_missing) # type: ignore d_scores = scores - truths d_scores *= not_missing - mean_square_error = (d_scores**2).mean() + mean_square_error = (d_scores ** 2).mean() return float(mean_square_error), d_scores def add_label(self, label: str) -> int: @@ -401,5 +401,9 @@ class TextCategorizer(TrainablePipe): def _validate_categories(self, examples: Iterable[Example]): """Check whether the provided examples all have single-label cats annotations.""" for ex in examples: - if list(ex.reference.cats.values()).count(1.0) > 1: + vals = list(ex.reference.cats.values()) + if vals.count(1.0) > 1: raise ValueError(Errors.E895.format(value=ex.reference.cats)) + for val in vals: + if not (val == 1.0 or val == 0.0): + raise ValueError(Errors.E851.format(val=val)) diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index eb83d9cb7..ef9bd6557 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -192,6 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer): for label in labels: self.add_label(label) subbatch = list(islice(get_examples(), 10)) + self._validate_categories(subbatch) + doc_sample = [eg.reference for eg in subbatch] label_sample, _ = self._examples_to_truth(subbatch) self._require_labels() @@ -202,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer): def _validate_categories(self, examples: Iterable[Example]): """This component allows any type of single- or multi-label annotations. This method overwrites the more strict one from 'textcat'.""" - pass + # check that annotation values are valid + for ex in examples: + for val in ex.reference.cats.values(): + if not (val == 1.0 or val == 0.0): + raise ValueError(Errors.E851.format(val=val)) diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index d359b77db..2eda9deaf 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -360,6 +360,30 @@ def test_label_types(name): nlp.initialize() +@pytest.mark.parametrize( + "name,get_examples", + [ + ("textcat", make_get_examples_single_label), + ("textcat_multilabel", make_get_examples_multi_label), + ], +) +def test_invalid_label_value(name, get_examples): + nlp = Language() + textcat = nlp.add_pipe(name) + example_getter = get_examples(nlp) + + def invalid_examples(): + # make one example with an invalid score + examples = example_getter() + ref = examples[0].reference + key = list(ref.cats.keys())[0] + ref.cats[key] = 2.0 + return examples + + with pytest.raises(ValueError): + nlp.initialize(get_examples=invalid_examples) + + @pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"]) def test_no_label(name): nlp = Language()