mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 18:03:04 +03:00
Check textcat values for validity (#11763)
* Check textcat values for validity
* Fix error numbers
* Clean up vals reference
* Check category value validity through training

  The _validate_categories is called in update, which for multilabel is
  inherited from the single label component.
* Formatting
This commit is contained in:
parent
317b6ef99c
commit
75bb7ad541
|
@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
"during training, make sure to include it in 'annotating components'")
|
"during training, make sure to include it in 'annotating components'")
|
||||||
|
|
||||||
# New errors added in v3.x
|
# New errors added in v3.x
|
||||||
|
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||||
|
"but found value of '{val}'.")
|
||||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
||||||
"traversal.")
|
"traversal.")
|
||||||
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
||||||
|
|
|
@ -293,7 +293,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
bp_scores(gradient)
|
bp_scores(gradient)
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
losses[self.name] += (gradient**2).sum()
|
losses[self.name] += (gradient ** 2).sum()
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def _examples_to_truth(
|
def _examples_to_truth(
|
||||||
|
@ -327,7 +327,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
not_missing = self.model.ops.asarray(not_missing) # type: ignore
|
not_missing = self.model.ops.asarray(not_missing) # type: ignore
|
||||||
d_scores = scores - truths
|
d_scores = scores - truths
|
||||||
d_scores *= not_missing
|
d_scores *= not_missing
|
||||||
mean_square_error = (d_scores**2).mean()
|
mean_square_error = (d_scores ** 2).mean()
|
||||||
return float(mean_square_error), d_scores
|
return float(mean_square_error), d_scores
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
def add_label(self, label: str) -> int:
|
||||||
|
@ -401,5 +401,9 @@ class TextCategorizer(TrainablePipe):
|
||||||
def _validate_categories(self, examples: Iterable[Example]):
    """Validate that every example carries single-label cats annotations.

    Raises a ValueError when an example marks more than one category as
    positive (E895), or when any category score is something other than
    a hard 0.0/1.0 value (E851).
    """
    for example in examples:
        scores = list(example.reference.cats.values())
        # More than one positive label contradicts the single-label contract.
        if scores.count(1.0) > 1:
            raise ValueError(Errors.E895.format(value=example.reference.cats))
        # Only exact 0/1 annotations are valid for training.
        for score in scores:
            if score != 1.0 and score != 0.0:
                raise ValueError(Errors.E851.format(val=score))
|
||||||
|
|
|
@ -192,6 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
for label in labels:
|
for label in labels:
|
||||||
self.add_label(label)
|
self.add_label(label)
|
||||||
subbatch = list(islice(get_examples(), 10))
|
subbatch = list(islice(get_examples(), 10))
|
||||||
|
self._validate_categories(subbatch)
|
||||||
|
|
||||||
doc_sample = [eg.reference for eg in subbatch]
|
doc_sample = [eg.reference for eg in subbatch]
|
||||||
label_sample, _ = self._examples_to_truth(subbatch)
|
label_sample, _ = self._examples_to_truth(subbatch)
|
||||||
self._require_labels()
|
self._require_labels()
|
||||||
|
@ -202,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
def _validate_categories(self, examples: Iterable[Example]):
    """This component allows any type of single- or multi-label annotations.
    This method overwrites the more strict one from 'textcat'."""
    # Any mix of positive labels is acceptable here, but every annotated
    # value must still be a hard 0.0/1.0 score.
    for example in examples:
        for score in example.reference.cats.values():
            if score != 0.0 and score != 1.0:
                raise ValueError(Errors.E851.format(val=score))
|
||||||
|
|
|
@ -360,6 +360,30 @@ def test_label_types(name):
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "name,get_examples",
    [
        ("textcat", make_get_examples_single_label),
        ("textcat_multilabel", make_get_examples_multi_label),
    ],
)
def test_invalid_label_value(name, get_examples):
    """Initialization must reject cats annotations outside {0.0, 1.0}."""
    nlp = Language()
    textcat = nlp.add_pipe(name)
    example_getter = get_examples(nlp)

    def corrupted_examples():
        # Corrupt the first example with an out-of-range category score.
        examples = example_getter()
        reference = examples[0].reference
        first_key = next(iter(reference.cats))
        reference.cats[first_key] = 2.0
        return examples

    with pytest.raises(ValueError):
        nlp.initialize(get_examples=corrupted_examples)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
||||||
def test_no_label(name):
|
def test_no_label(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user