Extend existing tests to spancat_exclusive

In this commit, I extended the existing tests for spancat to include
spancat_exclusive. I parametrized the test functions with 'name'
(similar var name with textcat and textcat_multilabel) for each
applicable test.

TODO: Add overfitting tests for spancat_exclusive
This commit is contained in:
Lj Miranda 2022-11-18 13:56:07 +08:00
parent 9a35b24b48
commit 3315540896

View File

@ -41,38 +41,42 @@ def make_examples(nlp, data=TRAIN_DATA):
return train_examples return train_examples
def test_no_label(): @pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_no_label(name):
nlp = Language() nlp = Language()
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
with pytest.raises(ValueError): with pytest.raises(ValueError):
nlp.initialize() nlp.initialize()
def test_no_resize(): @pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_no_resize(name):
nlp = Language() nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
spancat.add_label("Thing") spancat.add_label("Thing")
spancat.add_label("Phrase") spancat.add_label("Phrase")
assert spancat.labels == ("Thing", "Phrase") assert spancat.labels == ("Thing", "Phrase")
nlp.initialize() nlp.initialize()
assert spancat.model.get_dim("nO") == 2 assert spancat.model.get_dim("nO") == spancat._n_labels
# this throws an error because the spancat can't be resized after initialization # this throws an error because the spancat can't be resized after initialization
with pytest.raises(ValueError): with pytest.raises(ValueError):
spancat.add_label("Stuff") spancat.add_label("Stuff")
def test_implicit_labels(): @pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_implicit_labels(name):
nlp = Language() nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0 assert len(spancat.labels) == 0
train_examples = make_examples(nlp) train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
assert spancat.labels == ("PERSON", "LOC") assert spancat.labels == ("PERSON", "LOC")
def test_explicit_labels(): @pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_explicit_labels(name):
nlp = Language() nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
assert len(spancat.labels) == 0 assert len(spancat.labels) == 0
spancat.add_label("PERSON") spancat.add_label("PERSON")
spancat.add_label("LOC") spancat.add_label("LOC")
@ -371,7 +375,8 @@ def test_overfitting_IO_overlapping():
assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"} assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}
def test_zero_suggestions(): @pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_zero_suggestions(name):
# Test with a suggester that returns 0 suggestions # Test with a suggester that returns 0 suggestions
@registry.misc("test_zero_suggester") @registry.misc("test_zero_suggester")
@ -388,20 +393,21 @@ def test_zero_suggestions():
fix_random_seed(0) fix_random_seed(0)
nlp = English() nlp = English()
spancat = nlp.add_pipe( spancat = nlp.add_pipe(
"spancat", name,
config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY}, config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY},
) )
train_examples = make_examples(nlp) train_examples = make_examples(nlp)
optimizer = nlp.initialize(get_examples=lambda: train_examples) optimizer = nlp.initialize(get_examples=lambda: train_examples)
assert spancat.model.get_dim("nO") == 2 assert spancat.model.get_dim("nO") == spancat._n_labels
assert set(spancat.labels) == {"LOC", "PERSON"} assert set(spancat.labels) == {"LOC", "PERSON"}
nlp.update(train_examples, sgd=optimizer) nlp.update(train_examples, sgd=optimizer)
def test_set_candidates(): @pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"])
def test_set_candidates(name):
nlp = Language() nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
train_examples = make_examples(nlp) train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
texts = [ texts = [