mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 13:44:55 +03:00
Check category value validity through training
The _validate_categories is called in update, which for multilabel is inherited from the single label component.
This commit is contained in:
parent
49baa1bffc
commit
f51a63863d
|
@ -192,12 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
for label in labels:
|
||||
self.add_label(label)
|
||||
subbatch = list(islice(get_examples(), 10))
|
||||
self._validate_categories(subbatch)
|
||||
|
||||
# check that annotation values are valid
|
||||
for eg in subbatch:
|
||||
for val in eg.reference.cats.values():
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E851.format(val=val))
|
||||
doc_sample = [eg.reference for eg in subbatch]
|
||||
label_sample, _ = self._examples_to_truth(subbatch)
|
||||
self._require_labels()
|
||||
|
@ -208,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
def _validate_categories(self, examples: Iterable[Example]):
|
||||
"""This component allows any type of single- or multi-label annotations.
|
||||
This method overwrites the more strict one from 'textcat'."""
|
||||
pass
|
||||
# check that annotation values are valid
|
||||
for ex in examples:
|
||||
for val in ex.reference.cats.values():
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E851.format(val=val))
|
||||
|
|
Loading…
Reference in New Issue
Block a user