Extend existing tests to spancat_exclusive

In this commit, I extended the existing tests for spancat to include
spancat_exclusive. I parametrized the test functions with 'name'
(similar var name with textcat and textcat_multilabel) for each
applicable test.

TODO: Add overfitting tests for spancat_exclusive
This commit is contained in:
Lj Miranda 2022-11-18 13:56:07 +08:00
parent 9a35b24b48
commit 3315540896

View File

@ -41,38 +41,42 @@ def make_examples(nlp, data=TRAIN_DATA):
return train_examples
def test_no_label():
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_no_label(name):
nlp = Language()
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
with pytest.raises(ValueError):
nlp.initialize()
def test_no_resize():
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_no_resize(name):
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
spancat.add_label("Thing")
spancat.add_label("Phrase")
assert spancat.labels == ("Thing", "Phrase")
nlp.initialize()
assert spancat.model.get_dim("nO") == 2
assert spancat.model.get_dim("nO") == spancat._n_labels
# this throws an error because the spancat can't be resized after initialization
with pytest.raises(ValueError):
spancat.add_label("Stuff")
def test_implicit_labels():
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_implicit_labels(name):
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
assert spancat.labels == ("PERSON", "LOC")
def test_explicit_labels():
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_explicit_labels(name):
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0
spancat.add_label("PERSON")
spancat.add_label("LOC")
@ -371,7 +375,8 @@ def test_overfitting_IO_overlapping():
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
def test_zero_suggestions():
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_zero_suggestions(name):
# Test with a suggester that returns 0 suggestions
@registry.misc("test_zero_suggester")
@ -388,20 +393,21 @@ def test_zero_suggestions():
fix_random_seed(0)
nlp = English()
spancat = nlp.add_pipe(
"spancat",
name,
config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY},
)
train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert spancat.model.get_dim("nO") == 2
assert spancat.model.get_dim("nO") == spancat._n_labels
assert set(spancat.labels) == {"LOC", "PERSON"}
nlp.update(train_examples, sgd=optimizer)
def test_set_candidates():
@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_set_candidates(name):
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
texts = [