Check textcat values for validity

This commit is contained in:
Paul O'Leary McCann 2022-11-07 19:01:30 +09:00
parent b76222e56a
commit 0289676177
4 changed files with 35 additions and 0 deletions

View File

@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes):
"during training, make sure to include it in 'annotating components'")
# New errors added in v3.x
E852 = ("The 'textcat' component labels should only have values of 0 or 1, "
"but found value of '{val}'.")
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
"not permitted in factory names.")
E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "

View File

@ -403,3 +403,6 @@ class TextCategorizer(TrainablePipe):
for ex in examples:
if list(ex.reference.cats.values()).count(1.0) > 1:
raise ValueError(Errors.E895.format(value=ex.reference.cats))
for val in ex.reference.cats.values():
if not (val == 1.0 or val == 0.0):
raise ValueError(Errors.E852.format(val=val))

View File

@ -192,6 +192,12 @@ class MultiLabel_TextCategorizer(TextCategorizer):
for label in labels:
self.add_label(label)
subbatch = list(islice(get_examples(), 10))
# check that annotation values are valid
for eg in subbatch:
for val in eg.reference.cats.values():
if not (val == 1.0 or val == 0.0):
raise ValueError(Errors.E852.format(val=val))
doc_sample = [eg.reference for eg in subbatch]
label_sample, _ = self._examples_to_truth(subbatch)
self._require_labels()

View File

@ -360,6 +360,30 @@ def test_label_types(name):
nlp.initialize()
@pytest.mark.parametrize(
"name,get_examples",
[
("textcat", make_get_examples_single_label),
("textcat_multilabel", make_get_examples_multi_label),
],
)
def test_invalid_label_value(name, get_examples):
nlp = Language()
textcat = nlp.add_pipe(name)
example_getter = get_examples(nlp)
def invalid_examples():
# make one example with an invalid score
examples = example_getter()
ref = examples[0].reference
key = list(ref.cats.keys())[0]
ref.cats[key] = 2.0
return examples
with pytest.raises(ValueError):
nlp.initialize(get_examples=invalid_examples)
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
def test_no_label(name):
nlp = Language()