Check textcat values for validity (#11763)

* Check textcat values for validity

* Fix error numbers

* Clean up vals reference

* Check category value validity through training

_validate_categories is called in update, which the multilabel component inherits from the single-label component, so the value check also runs during training (see the sketch after this list).

* Formatting
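
As a rough illustration of what the new check rejects (the textcat_multilabel pipe, labels, and text below are only examples, assuming a spaCy build that includes this change):

```python
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
nlp.add_pipe("textcat_multilabel")

doc = nlp.make_doc("This is a sample text.")
# cats scores must be exactly 0.0 or 1.0; anything else (here 0.5)
# now raises a ValueError with E851 instead of being trained on silently
example = Example.from_dict(doc, {"cats": {"POSITIVE": 0.5, "NEGATIVE": 0.0}})

nlp.initialize(get_examples=lambda: [example])  # raises ValueError (E851)
```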
Paul O'Leary McCann 2022-11-17 18:25:01 +09:00 committed by GitHub
parent 317b6ef99c
commit 75bb7ad541
4 changed files with 40 additions and 4 deletions

spacy/errors.py

@@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes):
"during training, make sure to include it in 'annotating components'")
# New errors added in v3.x
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
"but found value of '{val}'.")
E852 = ("The tar file pulled from the remote attempted an unsafe path "
"traversal.")
E853 = ("Unsupported component factory name '{name}'. The character '.' is "

spacy/pipeline/textcat.py

@@ -293,7 +293,7 @@ class TextCategorizer(TrainablePipe):
bp_scores(gradient)
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += (gradient**2).sum()
losses[self.name] += (gradient ** 2).sum()
return losses
def _examples_to_truth(
@@ -327,7 +327,7 @@ class TextCategorizer(TrainablePipe):
not_missing = self.model.ops.asarray(not_missing) # type: ignore
d_scores = scores - truths
d_scores *= not_missing
mean_square_error = (d_scores**2).mean()
mean_square_error = (d_scores ** 2).mean()
return float(mean_square_error), d_scores
def add_label(self, label: str) -> int:
@@ -401,5 +401,9 @@ class TextCategorizer(TrainablePipe):
def _validate_categories(self, examples: Iterable[Example]):
"""Check whether the provided examples all have single-label cats annotations."""
for ex in examples:
if list(ex.reference.cats.values()).count(1.0) > 1:
vals = list(ex.reference.cats.values())
if vals.count(1.0) > 1:
raise ValueError(Errors.E895.format(value=ex.reference.cats))
for val in vals:
if not (val == 1.0 or val == 0.0):
raise ValueError(Errors.E851.format(val=val))

spacy/pipeline/textcat_multilabel.py

@@ -192,6 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
for label in labels:
self.add_label(label)
subbatch = list(islice(get_examples(), 10))
self._validate_categories(subbatch)
doc_sample = [eg.reference for eg in subbatch]
label_sample, _ = self._examples_to_truth(subbatch)
self._require_labels()
@@ -202,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
def _validate_categories(self, examples: Iterable[Example]):
"""This component allows any type of single- or multi-label annotations.
This method overwrites the more strict one from 'textcat'."""
pass
# check that annotation values are valid
for ex in examples:
for val in ex.reference.cats.values():
if not (val == 1.0 or val == 0.0):
raise ValueError(Errors.E851.format(val=val))
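
Because update in the multilabel component is inherited from TextCategorizer, which calls _validate_categories, the same check also fires at training time. A rough sketch (the labels and texts are made up, assuming a build with this change):

```python
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
nlp.add_pipe("textcat_multilabel")

# Valid annotations: every category score is exactly 0.0 or 1.0
good = Example.from_dict(nlp.make_doc("fine"), {"cats": {"A": 1.0, "B": 0.0}})
nlp.initialize(get_examples=lambda: [good])

# A score of 2.0 is neither 0.0 nor 1.0, so update() raises ValueError (E851)
bad = Example.from_dict(nlp.make_doc("broken"), {"cats": {"A": 2.0, "B": 0.0}})
nlp.update([bad])
```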

spacy/tests/pipeline/test_textcat.py

@@ -360,6 +360,30 @@ def test_label_types(name):
nlp.initialize()
@pytest.mark.parametrize(
"name,get_examples",
[
("textcat", make_get_examples_single_label),
("textcat_multilabel", make_get_examples_multi_label),
],
)
def test_invalid_label_value(name, get_examples):
nlp = Language()
textcat = nlp.add_pipe(name)
example_getter = get_examples(nlp)
def invalid_examples():
# make one example with an invalid score
examples = example_getter()
ref = examples[0].reference
key = list(ref.cats.keys())[0]
ref.cats[key] = 2.0
return examples
with pytest.raises(ValueError):
nlp.initialize(get_examples=invalid_examples)
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
def test_no_label(name):
nlp = Language()