From f7c6cf80be5b20545f8086661e7df73847319b9a Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 29 Nov 2022 15:23:39 +0100 Subject: [PATCH] Restrict cats_score to provided labels --- spacy/scorer.py | 14 ++++----- spacy/tests/test_language.py | 59 ++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/spacy/scorer.py b/spacy/scorer.py index 16fc303a0..1059685fc 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -473,17 +473,17 @@ class Scorer: threshold = 0.5 if multi_label else 0.0 if not multi_label: threshold = 0.0 + labels = set(labels) f_per_type = {label: PRFScore() for label in labels} auc_per_type = {label: ROCAUCScore() for label in labels} - labels = set(labels) - if labels: - for eg in examples: - labels.update(eg.predicted.cats.keys()) - labels.update(eg.reference.cats.keys()) for example in examples: # Through this loop, None in the gold_cats indicates missing label. - pred_cats = getter(example.predicted, attr) - gold_cats = getter(example.reference, attr) + pred_cats = { + k: v for k, v in getter(example.predicted, attr).items() if k in labels + } + gold_cats = { + k: v for k, v in getter(example.reference, attr).items() if k in labels + } for label in labels: pred_score = pred_cats.get(label, 0.0) diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 03a98d32f..5e89fff51 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -58,6 +58,29 @@ def nlp(): return nlp +@pytest.fixture +def nlp_tcm(): + nlp = Language(Vocab()) + textcat_multilabel = nlp.add_pipe("textcat_multilabel") + for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"): + textcat_multilabel.add_label(label) + nlp.initialize() + return nlp + + +@pytest.fixture +def nlp_tc_tcm(): + nlp = Language(Vocab()) + textcat = nlp.add_pipe("textcat") + for label in ("POSITIVE", "NEGATIVE"): + textcat.add_label(label) + textcat_multilabel = nlp.add_pipe("textcat_multilabel") + for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"): + textcat_multilabel.add_label(label) + nlp.initialize() + return nlp + + def test_language_update(nlp): text = "hello world" annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}} @@ -126,6 +149,42 @@ def test_evaluate_no_pipe(nlp): nlp.evaluate([Example.from_dict(doc, annots)]) +def test_evaluate_textcat_multilabel(nlp_tcm): + """Test that evaluate works with a multilabel textcat pipe.""" + text = "hello world" + annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}} + doc = Doc(nlp_tcm.vocab, words=text.split(" ")) + example = Example.from_dict(doc, annots) + scores = nlp_tcm.evaluate([example]) + labels = nlp_tcm.get_pipe("textcat_multilabel").labels + for label in labels: + assert scores["cats_f_per_type"].get(label) is not None + for key in example.reference.cats.keys(): + if key not in labels: + assert scores["cats_f_per_type"].get(key) is None + + +def test_evaluate_multiple_textcat(nlp_tc_tcm): + """Test that evaluate evaluates the final textcat component in a pipeline + with more than one textcat or textcat_multilabel.""" + text = "hello world" + annots = { + "doc_annotation": { + "cats": {"FEATURE": 1.0, "QUESTION": 1.0, "POSITIVE": 1.0, "NEGATIVE": 0.0} + } + } + doc = Doc(nlp_tc_tcm.vocab, words=text.split(" ")) + example = Example.from_dict(doc, annots) + scores = nlp_tc_tcm.evaluate([example]) + # get the labels from the final pipe + labels = nlp_tc_tcm.get_pipe(nlp_tc_tcm.pipe_names[-1]).labels + for label in labels: + assert scores["cats_f_per_type"].get(label) is not None + for key in example.reference.cats.keys(): + if key not in labels: + assert scores["cats_f_per_type"].get(key) is None + + def vector_modification_pipe(doc): doc.vector += 1 return doc