From f51a63863ddaa0382b5e8ef705d228c4b0e7af30 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Mon, 14 Nov 2022 19:31:21 +0900 Subject: [PATCH] Check category value validity through training The _validate_categories method is called in update, which the multilabel component inherits from the single-label component. --- spacy/pipeline/textcat_multilabel.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index 78750dabc..ef9bd6557 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -192,12 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer): for label in labels: self.add_label(label) subbatch = list(islice(get_examples(), 10)) + self._validate_categories(subbatch) - # check that annotation values are valid - for eg in subbatch: - for val in eg.reference.cats.values(): - if not (val == 1.0 or val == 0.0): - raise ValueError(Errors.E851.format(val=val)) doc_sample = [eg.reference for eg in subbatch] label_sample, _ = self._examples_to_truth(subbatch) self._require_labels() @@ -208,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer): def _validate_categories(self, examples: Iterable[Example]): """This component allows any type of single- or multi-label annotations. This method overwrites the more strict one from 'textcat'.""" - pass + # check that annotation values are valid + for ex in examples: + for val in ex.reference.cats.values(): + if not (val == 1.0 or val == 0.0): + raise ValueError(Errors.E851.format(val=val))