From 6d03b04901e95a71747a7e1ef0b00bc87bb2c807 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Mon, 9 Jan 2023 11:43:48 +0100 Subject: [PATCH] Improve score_cats for use with multiple textcat components (#11820) * add test for running evaluate on an nlp pipeline with two distinct textcat components * cleanup * merge dicts instead of overwrite * don't add more labels to the given set * Revert "merge dicts instead of overwrite" This reverts commit 89bee0ed7798389e6de882a0234e6075fbdaf331. * Switch tests to separate scorer keys rather than merged dicts * Revert unrelated edits * Switch textcat scorers to v2 * formatting Co-authored-by: Adriane Boyd --- spacy/pipeline/textcat_multilabel.py | 4 +- spacy/scorer.py | 6 +- spacy/tests/pipeline/test_textcat.py | 6 +- spacy/tests/test_language.py | 107 +++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 7 deletions(-) diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index 328cee723..41c0e2f63 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -74,7 +74,7 @@ subword_features = true default_config={ "threshold": 0.5, "model": DEFAULT_MULTI_TEXTCAT_MODEL, - "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"}, + "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"}, }, default_score_weights={ "cats_score": 1.0, @@ -120,7 +120,7 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str, ) -@registry.scorers("spacy.textcat_multilabel_scorer.v1") +@registry.scorers("spacy.textcat_multilabel_scorer.v2") def make_textcat_multilabel_scorer(): return textcat_multilabel_score diff --git a/spacy/scorer.py b/spacy/scorer.py index 16fc303a0..d8c383ab8 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -476,14 +476,12 @@ class Scorer: f_per_type = {label: PRFScore() for label in labels} auc_per_type = {label: ROCAUCScore() for label in labels} labels = set(labels) - if labels: - for eg in examples: - labels.update(eg.predicted.cats.keys()) - labels.update(eg.reference.cats.keys()) for example in examples: # Through this loop, None in the gold_cats indicates missing label. pred_cats = getter(example.predicted, attr) + pred_cats = {k: v for k, v in pred_cats.items() if k in labels} gold_cats = getter(example.reference, attr) + gold_cats = {k: v for k, v in gold_cats.items() if k in labels} for label in labels: pred_score = pred_cats.get(label, 0.0) diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 048586cec..d042f3445 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -898,7 +898,11 @@ def test_textcat_multi_threshold(): @pytest.mark.parametrize( - "component_name,scorer", [("textcat", "spacy.textcat_scorer.v1")] + "component_name,scorer", + [ + ("textcat", "spacy.textcat_scorer.v1"), + ("textcat_multilabel", "spacy.textcat_multilabel_scorer.v1"), + ], ) def test_textcat_legacy_scorers(component_name, scorer): """Check that legacy scorers are registered and produce the expected score diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 03a98d32f..03790eb86 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -3,6 +3,7 @@ import logging from unittest import mock import pytest from spacy.language import Language +from spacy.scorer import Scorer from spacy.tokens import Doc, Span from spacy.vocab import Vocab from spacy.training import Example @@ -126,6 +127,112 @@ def test_evaluate_no_pipe(nlp): nlp.evaluate([Example.from_dict(doc, annots)]) +def test_evaluate_textcat_multilabel(en_vocab): + """Test that evaluate works with a multilabel textcat pipe.""" + nlp = Language(en_vocab) + textcat_multilabel = nlp.add_pipe("textcat_multilabel") + for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"): + textcat_multilabel.add_label(label) + nlp.initialize() + + annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}} + doc = nlp.make_doc("hello world") + example = Example.from_dict(doc, annots) + scores = nlp.evaluate([example]) + labels = nlp.get_pipe("textcat_multilabel").labels + for label in labels: + assert scores["cats_f_per_type"].get(label) is not None + for key in example.reference.cats.keys(): + if key not in labels: + assert scores["cats_f_per_type"].get(key) is None + + +def test_evaluate_multiple_textcat_final(en_vocab): + """Test that evaluate evaluates the final textcat component in a pipeline + with more than one textcat or textcat_multilabel.""" + nlp = Language(en_vocab) + textcat = nlp.add_pipe("textcat") + for label in ("POSITIVE", "NEGATIVE"): + textcat.add_label(label) + textcat_multilabel = nlp.add_pipe("textcat_multilabel") + for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"): + textcat_multilabel.add_label(label) + nlp.initialize() + + annots = { + "cats": { + "POSITIVE": 1.0, + "NEGATIVE": 0.0, + "FEATURE": 1.0, + "QUESTION": 1.0, + "POSITIVE": 1.0, + "NEGATIVE": 0.0, + } + } + doc = nlp.make_doc("hello world") + example = Example.from_dict(doc, annots) + scores = nlp.evaluate([example]) + # get the labels from the final pipe + labels = nlp.get_pipe(nlp.pipe_names[-1]).labels + for label in labels: + assert scores["cats_f_per_type"].get(label) is not None + for key in example.reference.cats.keys(): + if key not in labels: + assert scores["cats_f_per_type"].get(key) is None + + +def test_evaluate_multiple_textcat_separate(en_vocab): + """Test that evaluate can evaluate multiple textcat components separately + with custom scorers.""" + + def custom_textcat_score(examples, **kwargs): + scores = Scorer.score_cats( + examples, + "cats", + multi_label=False, + **kwargs, + ) + return {f"custom_{k}": v for k, v in scores.items()} + + @spacy.registry.scorers("test_custom_textcat_scorer") + def make_custom_textcat_scorer(): + return custom_textcat_score + + nlp = Language(en_vocab) + textcat = nlp.add_pipe( + "textcat", + config={"scorer": {"@scorers": "test_custom_textcat_scorer"}}, + ) + for label in ("POSITIVE", "NEGATIVE"): + textcat.add_label(label) + textcat_multilabel = nlp.add_pipe("textcat_multilabel") + for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"): + textcat_multilabel.add_label(label) + nlp.initialize() + + annots = { + "cats": { + "POSITIVE": 1.0, + "NEGATIVE": 0.0, + "FEATURE": 1.0, + "QUESTION": 1.0, + "POSITIVE": 1.0, + "NEGATIVE": 0.0, + } + } + doc = nlp.make_doc("hello world") + example = Example.from_dict(doc, annots) + scores = nlp.evaluate([example]) + # check custom scores for the textcat pipe + assert "custom_cats_f_per_type" in scores + labels = nlp.get_pipe("textcat").labels + assert set(scores["custom_cats_f_per_type"].keys()) == set(labels) + # check default scores for the textcat_multilabel pipe + assert "cats_f_per_type" in scores + labels = nlp.get_pipe("textcat_multilabel").labels + assert set(scores["cats_f_per_type"].keys()) == set(labels) + + def vector_modification_pipe(doc): doc.vector += 1 return doc