mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 13:44:55 +03:00
Check textcat values for validity
This commit is contained in:
parent
b76222e56a
commit
0289676177
|
@ -544,6 +544,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"during training, make sure to include it in 'annotating components'")
|
||||
|
||||
# New errors added in v3.x
|
||||
E852 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||
"but found value of '{val}'.")
|
||||
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
||||
"not permitted in factory names.")
|
||||
E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "
|
||||
|
|
|
@ -403,3 +403,6 @@ class TextCategorizer(TrainablePipe):
|
|||
for ex in examples:
|
||||
if list(ex.reference.cats.values()).count(1.0) > 1:
|
||||
raise ValueError(Errors.E895.format(value=ex.reference.cats))
|
||||
for val in ex.reference.cats.values():
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E852.format(val=val))
|
||||
|
|
|
@ -192,6 +192,12 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
for label in labels:
|
||||
self.add_label(label)
|
||||
subbatch = list(islice(get_examples(), 10))
|
||||
|
||||
# check that annotation values are valid
|
||||
for eg in subbatch:
|
||||
for val in eg.reference.cats.values():
|
||||
if not (val == 1.0 or val == 0.0):
|
||||
raise ValueError(Errors.E852.format(val=val))
|
||||
doc_sample = [eg.reference for eg in subbatch]
|
||||
label_sample, _ = self._examples_to_truth(subbatch)
|
||||
self._require_labels()
|
||||
|
|
|
@ -360,6 +360,30 @@ def test_label_types(name):
|
|||
nlp.initialize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,get_examples",
|
||||
[
|
||||
("textcat", make_get_examples_single_label),
|
||||
("textcat_multilabel", make_get_examples_multi_label),
|
||||
],
|
||||
)
|
||||
def test_invalid_label_value(name, get_examples):
|
||||
nlp = Language()
|
||||
textcat = nlp.add_pipe(name)
|
||||
example_getter = get_examples(nlp)
|
||||
|
||||
def invalid_examples():
|
||||
# make one example with an invalid score
|
||||
examples = example_getter()
|
||||
ref = examples[0].reference
|
||||
key = list(ref.cats.keys())[0]
|
||||
ref.cats[key] = 2.0
|
||||
return examples
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
nlp.initialize(get_examples=invalid_examples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
||||
def test_no_label(name):
|
||||
nlp = Language()
|
||||
|
|
Loading…
Reference in New Issue
Block a user