Use a single variable for tests

This commit is contained in:
Lj Miranda 2022-11-29 11:25:35 +08:00
parent ac0ac3eb99
commit 4ab27d4517

View File

@ -15,6 +15,8 @@ OPS = get_current_ops()
SPAN_KEY = "labeled_spans"
SPANCAT_COMPONENTS = ["spancat", "spancat_exclusive"]
TRAIN_DATA = [
("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
(
@ -41,7 +43,7 @@ def make_examples(nlp, data=TRAIN_DATA):
return train_examples
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_no_label(name):
nlp = Language()
nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
@ -49,7 +51,7 @@ def test_no_label(name):
nlp.initialize()
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_no_resize(name):
nlp = Language()
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
@ -63,7 +65,7 @@ def test_no_resize(name):
spancat.add_label("Stuff")
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_implicit_labels(name):
nlp = Language()
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
@ -73,7 +75,7 @@ def test_implicit_labels(name):
assert spancat.labels == ("PERSON", "LOC")
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_explicit_labels(name):
nlp = Language()
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
@ -375,7 +377,7 @@ def test_overfitting_IO_overlapping():
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_zero_suggestions(name):
# Test with a suggester that returns 0 suggestions
@ -404,7 +406,7 @@ def test_zero_suggestions(name):
nlp.update(train_examples, sgd=optimizer)
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_set_candidates(name):
nlp = Language()
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})