From 3315540896662f95a76d1e33b95f30b5e65b9e5d Mon Sep 17 00:00:00 2001 From: Lj Miranda Date: Fri, 18 Nov 2022 13:56:07 +0800 Subject: [PATCH] Extend existing tests to spancat_exclusive In this commit, I extended the existing tests for spancat to include spancat_exclusive. I parametrized each applicable test function with 'name' (the same variable name used in the textcat and textcat_multilabel tests). TODO: Add overfitting tests for spancat_exclusive --- spacy/tests/pipeline/test_spancat.py | 34 ++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 15256a763..7164dc3eb 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -41,38 +41,42 @@ def make_examples(nlp, data=TRAIN_DATA): return train_examples -def test_no_label(): +@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +def test_no_label(name): nlp = Language() - nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) with pytest.raises(ValueError): nlp.initialize() -def test_no_resize(): +@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +def test_no_resize(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) spancat.add_label("Thing") spancat.add_label("Phrase") assert spancat.labels == ("Thing", "Phrase") nlp.initialize() - assert spancat.model.get_dim("nO") == 2 + assert spancat.model.get_dim("nO") == spancat._n_labels # this throws an error because the spancat can't be resized after initialization with pytest.raises(ValueError): spancat.add_label("Stuff") -def test_implicit_labels(): +@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +def test_implicit_labels(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) 
+ spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) assert len(spancat.labels) == 0 train_examples = make_examples(nlp) nlp.initialize(get_examples=lambda: train_examples) assert spancat.labels == ("PERSON", "LOC") -def test_explicit_labels(): +@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +def test_explicit_labels(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) assert len(spancat.labels) == 0 spancat.add_label("PERSON") spancat.add_label("LOC") @@ -371,7 +375,8 @@ def test_overfitting_IO_overlapping(): assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"} -def test_zero_suggestions(): +@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +def test_zero_suggestions(name): # Test with a suggester that returns 0 suggestions @registry.misc("test_zero_suggester") @@ -388,20 +393,21 @@ def test_zero_suggestions(): fix_random_seed(0) nlp = English() spancat = nlp.add_pipe( - "spancat", + name, config={"suggester": {"@misc": "test_zero_suggester"}, "spans_key": SPAN_KEY}, ) train_examples = make_examples(nlp) optimizer = nlp.initialize(get_examples=lambda: train_examples) - assert spancat.model.get_dim("nO") == 2 + assert spancat.model.get_dim("nO") == spancat._n_labels assert set(spancat.labels) == {"LOC", "PERSON"} nlp.update(train_examples, sgd=optimizer) -def test_set_candidates(): +@pytest.mark.parametrize("name", ["spancat", "spancat_exclusive"]) +def test_set_candidates(name): nlp = Language() - spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY}) + spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY}) train_examples = make_examples(nlp) nlp.initialize(get_examples=lambda: train_examples) texts = [