diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index fc9229867..03790eb86 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -3,6 +3,7 @@ import logging
 from unittest import mock
 import pytest
 from spacy.language import Language
+from spacy.scorer import Scorer
 from spacy.tokens import Doc, Span
 from spacy.vocab import Vocab
 from spacy.training import Example
@@ -58,29 +59,6 @@ def nlp():
     return nlp
 
 
-@pytest.fixture
-def nlp_multi():
-    nlp = Language(Vocab())
-    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
-    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
-        textcat_multilabel.add_label(label)
-    nlp.initialize()
-    return nlp
-
-
-@pytest.fixture
-def nlp_both():
-    nlp = Language(Vocab())
-    textcat = nlp.add_pipe("textcat")
-    for label in ("POSITIVE", "NEGATIVE"):
-        textcat.add_label(label)
-    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
-    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
-        textcat_multilabel.add_label(label)
-    nlp.initialize()
-    return nlp
-
-
 def test_language_update(nlp):
     text = "hello world"
     annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
@@ -114,9 +92,6 @@ def test_language_evaluate(nlp):
     example = Example.from_dict(doc, annots)
     scores = nlp.evaluate([example])
     assert scores["speed"] > 0
-    assert scores["cats_f_per_type"].get("POSITIVE") is not None
-    assert scores["cats_f_per_type"].get("NEGATIVE") is not None
-    assert scores["cats_f_per_type"].get("BUG") is None
 
     # test with generator
     scores = nlp.evaluate(eg for eg in [example])
@@ -152,33 +127,106 @@ def test_evaluate_no_pipe(nlp):
     nlp.evaluate([Example.from_dict(doc, annots)])
 
 
-def test_evaluate_textcat(nlp_multi):
+def test_evaluate_textcat_multilabel(en_vocab):
     """Test that evaluate works with a multilabel textcat pipe."""
-    text = "hello world"
-    annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}}
-    doc = Doc(nlp_multi.vocab, words=text.split(" "))
+    nlp = Language(en_vocab)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
+    doc = nlp.make_doc("hello world")
     example = Example.from_dict(doc, annots)
-    scores = nlp_multi.evaluate([example])
-    assert scores["cats_f_per_type"].get("FEATURE") is not None
-    assert scores["cats_f_per_type"].get("QUESTION") is not None
-    assert scores["cats_f_per_type"].get("REQUEST") is not None
-    assert scores["cats_f_per_type"].get("BUG") is not None
-    assert scores["cats_f_per_type"].get("POSITIVE") is None
-    assert scores["cats_f_per_type"].get("NEGATIVE") is None
+    scores = nlp.evaluate([example])
+    labels = nlp.get_pipe("textcat_multilabel").labels
+    for label in labels:
+        assert scores["cats_f_per_type"].get(label) is not None
+    for key in example.reference.cats.keys():
+        if key not in labels:
+            assert scores["cats_f_per_type"].get(key) is None
 
 
-def test_evaluate_both(nlp_both):
-    """Test that evaluate works with two textcat pipes."""
-    text = "hello world"
-    annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0, "POSITIVE": 1.0, "NEGATIVE": 0.0}}}
-    doc = Doc(nlp_both.vocab, words=text.split(" "))
+def test_evaluate_multiple_textcat_final(en_vocab):
+    """Test that evaluate evaluates the final textcat component in a pipeline
+    with more than one textcat or textcat_multilabel."""
+    nlp = Language(en_vocab)
+    textcat = nlp.add_pipe("textcat")
+    for label in ("POSITIVE", "NEGATIVE"):
+        textcat.add_label(label)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    annots = {
+        "cats": {
+            "POSITIVE": 1.0,
+            "NEGATIVE": 0.0,
+            "FEATURE": 1.0,
+            "QUESTION": 1.0,
+        }
+    }
+    doc = nlp.make_doc("hello world")
     example = Example.from_dict(doc, annots)
-    scores = nlp_both.evaluate([example])
-    assert scores["cats_f_per_type"].get("FEATURE") is not None
-    assert scores["cats_f_per_type"].get("QUESTION") is not None
-    assert scores["cats_f_per_type"].get("BUG") is not None
-    assert scores["cats_f_per_type"].get("POSITIVE") is not None
-    assert scores["cats_f_per_type"].get("NEGATIVE") is not None
+    scores = nlp.evaluate([example])
+    # get the labels from the final pipe
+    labels = nlp.get_pipe(nlp.pipe_names[-1]).labels
+    for label in labels:
+        assert scores["cats_f_per_type"].get(label) is not None
+    for key in example.reference.cats.keys():
+        if key not in labels:
+            assert scores["cats_f_per_type"].get(key) is None
+
+
+def test_evaluate_multiple_textcat_separate(en_vocab):
+    """Test that evaluate can evaluate multiple textcat components separately
+    with custom scorers."""
+
+    def custom_textcat_score(examples, **kwargs):
+        scores = Scorer.score_cats(
+            examples,
+            "cats",
+            multi_label=False,
+            **kwargs,
+        )
+        return {f"custom_{k}": v for k, v in scores.items()}
+
+    @spacy.registry.scorers("test_custom_textcat_scorer")
+    def make_custom_textcat_scorer():
+        return custom_textcat_score
+
+    nlp = Language(en_vocab)
+    textcat = nlp.add_pipe(
+        "textcat",
+        config={"scorer": {"@scorers": "test_custom_textcat_scorer"}},
+    )
+    for label in ("POSITIVE", "NEGATIVE"):
+        textcat.add_label(label)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    annots = {
+        "cats": {
+            "POSITIVE": 1.0,
+            "NEGATIVE": 0.0,
+            "FEATURE": 1.0,
+            "QUESTION": 1.0,
+        }
+    }
+    doc = nlp.make_doc("hello world")
+    example = Example.from_dict(doc, annots)
+    scores = nlp.evaluate([example])
+    # check custom scores for the textcat pipe
+    assert "custom_cats_f_per_type" in scores
+    labels = nlp.get_pipe("textcat").labels
+    assert set(scores["custom_cats_f_per_type"].keys()) == set(labels)
+    # check default scores for the textcat_multilabel pipe
+    assert "cats_f_per_type" in scores
+    labels = nlp.get_pipe("textcat_multilabel").labels
+    assert set(scores["cats_f_per_type"].keys()) == set(labels)
 
 
 def vector_modification_pipe(doc):