mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 05:04:09 +03:00
Check textcat values for validity (#11763)
* Check textcat values for validity * Fix error numbers * Clean up vals reference * Check category value validity through training The _validate_categories is called in update, which for multilabel is inherited from the single label component. * Formatting
This commit is contained in:
parent
317b6ef99c
commit
75bb7ad541
|
@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"during training, make sure to include it in 'annotating components'")
|
||||
|
||||
# New errors added in v3.x
|
||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||
"but found value of '{val}'.")
|
||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
||||
"traversal.")
|
||||
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
||||
|
|
|
@ -293,7 +293,7 @@ class TextCategorizer(TrainablePipe):
|
|||
bp_scores(gradient)
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
losses[self.name] += (gradient**2).sum()
|
||||
losses[self.name] += (gradient ** 2).sum()
|
||||
return losses
|
||||
|
||||
def _examples_to_truth(
|
||||
|
@ -327,7 +327,7 @@ class TextCategorizer(TrainablePipe):
|
|||
not_missing = self.model.ops.asarray(not_missing) # type: ignore
|
||||
d_scores = scores - truths
|
||||
d_scores *= not_missing
|
||||
mean_square_error = (d_scores**2).mean()
|
||||
mean_square_error = (d_scores ** 2).mean()
|
||||
return float(mean_square_error), d_scores
|
||||
|
||||
def add_label(self, label: str) -> int:
|
||||
|
@ -401,5 +401,9 @@ class TextCategorizer(TrainablePipe):
|
|||
def _validate_categories(self, examples: Iterable[Example]):
|
||||
"""Check whether the provided examples all have single-label cats annotations."""
|
||||
for ex in examples:
|
||||
if list(ex.reference.cats.values()).count(1.0) > 1:
|
||||
vals = list(ex.reference.cats.values())
|
||||
if vals.count(1.0) > 1:
|
||||
raise ValueError(Errors.E895.format(value=ex.reference.cats))
|
||||
for val in vals:
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E851.format(val=val))
|
||||
|
|
|
@ -192,6 +192,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
for label in labels:
|
||||
self.add_label(label)
|
||||
subbatch = list(islice(get_examples(), 10))
|
||||
self._validate_categories(subbatch)
|
||||
|
||||
doc_sample = [eg.reference for eg in subbatch]
|
||||
label_sample, _ = self._examples_to_truth(subbatch)
|
||||
self._require_labels()
|
||||
|
@ -202,4 +204,8 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
def _validate_categories(self, examples: Iterable[Example]):
|
||||
"""This component allows any type of single- or multi-label annotations.
|
||||
This method overwrites the more strict one from 'textcat'."""
|
||||
pass
|
||||
# check that annotation values are valid
|
||||
for ex in examples:
|
||||
for val in ex.reference.cats.values():
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E851.format(val=val))
|
||||
|
|
|
@ -360,6 +360,30 @@ def test_label_types(name):
|
|||
nlp.initialize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,get_examples",
|
||||
[
|
||||
("textcat", make_get_examples_single_label),
|
||||
("textcat_multilabel", make_get_examples_multi_label),
|
||||
],
|
||||
)
|
||||
def test_invalid_label_value(name, get_examples):
|
||||
nlp = Language()
|
||||
textcat = nlp.add_pipe(name)
|
||||
example_getter = get_examples(nlp)
|
||||
|
||||
def invalid_examples():
|
||||
# make one example with an invalid score
|
||||
examples = example_getter()
|
||||
ref = examples[0].reference
|
||||
key = list(ref.cats.keys())[0]
|
||||
ref.cats[key] = 2.0
|
||||
return examples
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
nlp.initialize(get_examples=invalid_examples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
||||
def test_no_label(name):
|
||||
nlp = Language()
|
||||
|
|
Loading…
Reference in New Issue
Block a user