From ae919e990798dbc5a32d5bcea3f290da965fe8ad Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Tue, 13 Dec 2022 14:37:13 +0100
Subject: [PATCH] Move pipeline definitions into tests

---
 spacy/tests/test_language.py | 56 +++++++++++++++---------------------
 1 file changed, 23 insertions(+), 33 deletions(-)

diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index c4d968582..d3246224e 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -58,29 +58,6 @@ def nlp():
     return nlp
 
 
-@pytest.fixture
-def nlp_tcm():
-    nlp = Language(Vocab())
-    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
-    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
-        textcat_multilabel.add_label(label)
-    nlp.initialize()
-    return nlp
-
-
-@pytest.fixture
-def nlp_tc_tcm():
-    nlp = Language(Vocab())
-    textcat = nlp.add_pipe("textcat")
-    for label in ("POSITIVE", "NEGATIVE"):
-        textcat.add_label(label)
-    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
-    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
-        textcat_multilabel.add_label(label)
-    nlp.initialize()
-    return nlp
-
-
 def test_language_update(nlp):
     text = "hello world"
     annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
@@ -149,14 +126,19 @@ def test_evaluate_no_pipe(nlp):
     nlp.evaluate([Example.from_dict(doc, annots)])
 
 
-def test_evaluate_textcat_multilabel(nlp_tcm):
+def test_evaluate_textcat_multilabel(en_vocab):
     """Test that evaluate works with a multilabel textcat pipe."""
-    text = "hello world"
+    nlp = Language(en_vocab)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
     annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
-    doc = Doc(nlp_tcm.vocab, words=text.split(" "))
+    doc = nlp.make_doc("hello world")
     example = Example.from_dict(doc, annots)
-    scores = nlp_tcm.evaluate([example])
-    labels = nlp_tcm.get_pipe("textcat_multilabel").labels
+    scores = nlp.evaluate([example])
+    labels = nlp.get_pipe("textcat_multilabel").labels
     for label in labels:
         assert scores["cats_f_per_type"].get(label) is not None
     for key in example.reference.cats.keys():
@@ -164,10 +146,18 @@ def test_evaluate_textcat_multilabel(nlp_tcm):
             assert scores["cats_f_per_type"].get(key) is None
 
 
-def test_evaluate_multiple_textcat(nlp_tc_tcm):
+def test_evaluate_multiple_textcat_final(en_vocab):
     """Test that evaluate evaluates the final textcat component in a pipeline
     with more than one textcat or textcat_multilabel."""
-    text = "hello world"
+    nlp = Language(en_vocab)
+    textcat = nlp.add_pipe("textcat")
+    for label in ("POSITIVE", "NEGATIVE"):
+        textcat.add_label(label)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
     annots = {
         "cats": {
             "POSITIVE": 1.0,
@@ -178,11 +168,11 @@ def test_evaluate_multiple_textcat(nlp_tc_tcm):
             "NEGATIVE": 0.0,
         }
     }
-    doc = Doc(nlp_tc_tcm.vocab, words=text.split(" "))
+    doc = nlp.make_doc("hello world")
     example = Example.from_dict(doc, annots)
-    scores = nlp_tc_tcm.evaluate([example])
+    scores = nlp.evaluate([example])
     # get the labels from the final pipe
-    labels = nlp_tc_tcm.get_pipe(nlp_tc_tcm.pipe_names[-1]).labels
+    labels = nlp.get_pipe(nlp.pipe_names[-1]).labels
     for label in labels:
         assert scores["cats_f_per_type"].get(label) is not None
     for key in example.reference.cats.keys():
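
For context only (not part of the patch): the snippet below is a minimal standalone sketch of the inline setup pattern the moved tests now use, so the refactor can be sanity-checked outside pytest. It assumes spaCy is installed and uses only calls already visible in the diff (Language, Vocab, add_pipe, add_label, initialize, make_doc, Example.from_dict, evaluate); the import paths are spaCy's public modules.

    # Minimal sketch mirroring the inline setup in test_evaluate_textcat_multilabel.
    # Assumes spaCy v3+ is installed; not part of the committed test file.
    from spacy.language import Language
    from spacy.training import Example
    from spacy.vocab import Vocab

    # Build the pipeline inline instead of via a shared fixture.
    nlp = Language(Vocab())
    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
        textcat_multilabel.add_label(label)
    nlp.initialize()

    # Evaluate a single example and inspect the per-label F-scores.
    annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
    doc = nlp.make_doc("hello world")
    example = Example.from_dict(doc, annots)
    scores = nlp.evaluate([example])
    print(scores["cats_f_per_type"])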