Restrict cats_score to provided labels

This commit is contained in:
Adriane Boyd 2022-11-29 15:23:39 +01:00
parent 6f9d630f7e
commit f7c6cf80be
2 changed files with 66 additions and 7 deletions

View File

@ -473,17 +473,17 @@ class Scorer:
threshold = 0.5 if multi_label else 0.0 threshold = 0.5 if multi_label else 0.0
if not multi_label: if not multi_label:
threshold = 0.0 threshold = 0.0
labels = set(labels)
f_per_type = {label: PRFScore() for label in labels} f_per_type = {label: PRFScore() for label in labels}
auc_per_type = {label: ROCAUCScore() for label in labels} auc_per_type = {label: ROCAUCScore() for label in labels}
labels = set(labels)
if labels:
for eg in examples:
labels.update(eg.predicted.cats.keys())
labels.update(eg.reference.cats.keys())
for example in examples: for example in examples:
# Through this loop, None in the gold_cats indicates missing label. # Through this loop, None in the gold_cats indicates missing label.
pred_cats = getter(example.predicted, attr) pred_cats = {
gold_cats = getter(example.reference, attr) k: v for k, v in getter(example.predicted, attr).items() if k in labels
}
gold_cats = {
k: v for k, v in getter(example.reference, attr).items() if k in labels
}
for label in labels: for label in labels:
pred_score = pred_cats.get(label, 0.0) pred_score = pred_cats.get(label, 0.0)

View File

@ -58,6 +58,29 @@ def nlp():
return nlp return nlp
@pytest.fixture
def nlp_tcm():
nlp = Language(Vocab())
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
textcat_multilabel.add_label(label)
nlp.initialize()
return nlp
@pytest.fixture
def nlp_tc_tcm():
nlp = Language(Vocab())
textcat = nlp.add_pipe("textcat")
for label in ("POSITIVE", "NEGATIVE"):
textcat.add_label(label)
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
textcat_multilabel.add_label(label)
nlp.initialize()
return nlp
def test_language_update(nlp): def test_language_update(nlp):
text = "hello world" text = "hello world"
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}} annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
@ -126,6 +149,42 @@ def test_evaluate_no_pipe(nlp):
nlp.evaluate([Example.from_dict(doc, annots)]) nlp.evaluate([Example.from_dict(doc, annots)])
def test_evaluate_textcat_multilabel(nlp_tcm):
"""Test that evaluate works with a multilabel textcat pipe."""
text = "hello world"
annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}}
doc = Doc(nlp_tcm.vocab, words=text.split(" "))
example = Example.from_dict(doc, annots)
scores = nlp_tcm.evaluate([example])
labels = nlp_tcm.get_pipe("textcat_multilabel").labels
for label in labels:
assert scores["cats_f_per_type"].get(label) is not None
for key in example.reference.cats.keys():
if key not in labels:
assert scores["cats_f_per_type"].get(key) is None
def test_evaluate_multiple_textcat(nlp_tc_tcm):
"""Test that evaluate evaluates the final textcat component in a pipeline
with more than one textcat or textcat_multilabel."""
text = "hello world"
annots = {
"doc_annotation": {
"cats": {"FEATURE": 1.0, "QUESTION": 1.0, "POSITIVE": 1.0, "NEGATIVE": 0.0}
}
}
doc = Doc(nlp_tc_tcm.vocab, words=text.split(" "))
example = Example.from_dict(doc, annots)
scores = nlp_tc_tcm.evaluate([example])
# get the labels from the final pipe
labels = nlp_tc_tcm.get_pipe(nlp_tc_tcm.pipe_names[-1]).labels
for label in labels:
assert scores["cats_f_per_type"].get(label) is not None
for key in example.reference.cats.keys():
if key not in labels:
assert scores["cats_f_per_type"].get(key) is None
def vector_modification_pipe(doc): def vector_modification_pipe(doc):
doc.vector += 1 doc.vector += 1
return doc return doc