diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 7164dc3eb..fa3b76f30 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -15,6 +15,8 @@ OPS = get_current_ops() SPAN_KEY = "labeled_spans" +SPANCAT_COMPONENTS = ["spancat", "spancat_exclusive"] + TRAIN_DATA = [ ("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}), ( @@ -41,7 +43,7 @@ def make_examples(nlp, data=TRAIN_DATA): return train_examples -@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) def test_no_label(name): nlp = Language() nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) @@ -49,7 +51,7 @@ def test_no_label(name): nlp.initialize() -@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) def test_no_resize(name): nlp = Language() spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) @@ -63,7 +65,7 @@ def test_no_resize(name): spancat.add_label("Stuff") -@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) def test_implicit_labels(name): nlp = Language() spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) @@ -73,7 +75,7 @@ def test_implicit_labels(name): assert spancat.labels == ("PERSON", "LOC") -@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) def test_explicit_labels(name): nlp = Language() spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) @@ -375,7 +377,7 @@ def test_overfitting_IO_overlapping(): assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"} -@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) def test_zero_suggestions(name): # Test with a suggester that returns 0 suggestions @@ -404,7 +406,7 @@ def test_zero_suggestions(name): nlp.update(train_examples, sgd=optimizer) -@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +@pytest.mark.parametrize("name", SPANCAT_COMPONENTS) def test_set_candidates(name): nlp = Language() spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})