From 028967617731b0961afe5116dc5d4168160a6c71 Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Mon, 7 Nov 2022 19:01:30 +0900
Subject: [PATCH] Check textcat values for validity

Raise a ValueError (new error E852) when a `cats` annotation value is
anything other than 0 or 1. In `textcat` the check runs next to the
existing E895 check on the same examples; in `textcat_multilabel` it
runs over the sub-batch of up to 10 examples taken from
`get_examples()`.

Add a parametrized test covering both components.
---
 spacy/errors.py                      |  2 ++
 spacy/pipeline/textcat.py            |  3 +++
 spacy/pipeline/textcat_multilabel.py |  6 ++++++
 spacy/tests/pipeline/test_textcat.py | 24 ++++++++++++++++++++++++
 4 files changed, 35 insertions(+)

diff --git a/spacy/errors.py b/spacy/errors.py
index e0628819d..00b5221f1 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes):
             "during training, make sure to include it in 'annotating components'")
 
     # New errors added in v3.x
+    E852 = ("The 'textcat' component labels should only have values of 0 or 1, "
+            "but found value of '{val}'.")
     E853 = ("Unsupported component factory name '{name}'. The character '.' is "
             "not permitted in factory names.")
     E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 4023c4456..76726d436 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -403,3 +403,6 @@ class TextCategorizer(TrainablePipe):
         for ex in examples:
             if list(ex.reference.cats.values()).count(1.0) > 1:
                 raise ValueError(Errors.E895.format(value=ex.reference.cats))
+            for val in ex.reference.cats.values():
+                if not (val == 1.0 or val == 0.0):
+                    raise ValueError(Errors.E852.format(val=val))
diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py
index eb83d9cb7..b777e3bec 100644
--- a/spacy/pipeline/textcat_multilabel.py
+++ b/spacy/pipeline/textcat_multilabel.py
@@ -192,6 +192,12 @@ class MultiLabel_TextCategorizer(TextCategorizer):
             for label in labels:
                 self.add_label(label)
         subbatch = list(islice(get_examples(), 10))
+
+        # check that annotation values are valid
+        for eg in subbatch:
+            for val in eg.reference.cats.values():
+                if not (val == 1.0 or val == 0.0):
+                    raise ValueError(Errors.E852.format(val=val))
         doc_sample = [eg.reference for eg in subbatch]
         label_sample, _ = self._examples_to_truth(subbatch)
         self._require_labels()
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index d359b77db..2eda9deaf 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -360,6 +360,30 @@ def test_label_types(name):
         nlp.initialize()
 
 
+@pytest.mark.parametrize(
+    "name,get_examples",
+    [
+        ("textcat", make_get_examples_single_label),
+        ("textcat_multilabel", make_get_examples_multi_label),
+    ],
+)
+def test_invalid_label_value(name, get_examples):
+    nlp = Language()
+    textcat = nlp.add_pipe(name)
+    example_getter = get_examples(nlp)
+
+    def invalid_examples():
+        # make one example with an invalid score
+        examples = example_getter()
+        ref = examples[0].reference
+        key = list(ref.cats.keys())[0]
+        ref.cats[key] = 2.0
+        return examples
+
+    with pytest.raises(ValueError):
+        nlp.initialize(get_examples=invalid_examples)
+
+
 @pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
 def test_no_label(name):
     nlp = Language()