Check category value validity through training

The _validate_categories is called in update, which for multilabel is
inherited from the single label component.
This commit is contained in:
Paul O'Leary McCann 2022-11-14 19:31:21 +09:00
parent 49baa1bffc
commit f51a63863d

View File

@ -192,12 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
for label in labels:
self.add_label(label)
subbatch = list(islice(get_examples(), 10))
self._validate_categories(subbatch)
# check that annotation values are valid
for eg in subbatch:
for val in eg.reference.cats.values():
if not (val == 1.0 or val == 0.0):
raise ValueError(Errors.E851.format(val=val))
doc_sample = [eg.reference for eg in subbatch]
label_sample, _ = self._examples_to_truth(subbatch)
self._require_labels()
@ -208,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
def _validate_categories(self, examples: Iterable[Example]):
"""This component allows any type of single- or multi-label annotations.
This method overwrites the more strict one from 'textcat'."""
pass
# check that annotation values are valid
for ex in examples:
for val in ex.reference.cats.values():
if not (val == 1.0 or val == 0.0):
raise ValueError(Errors.E851.format(val=val))