diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index c4615f793..99128cfa6 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -108,13 +108,13 @@ def test_doc_gc(): # XXX This fails with length 0 sometimes assert len(spangroup) > 0 with pytest.raises(RuntimeError): - span = spangroup[0] + spangroup[0] @pytest.mark.parametrize( "max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)] ) -def test_make_spangroup(max_positive, nr_results): +def test_make_spangroup_multilabel(max_positive, nr_results): fix_random_seed(0) nlp = Language() spancat = nlp.add_pipe( @@ -160,6 +160,55 @@ def test_make_spangroup(max_positive, nr_results): assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5) +@pytest.mark.parametrize( + "threshold,allow_overlap,nr_results", [ + (0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)] +) +def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results): + fix_random_seed(0) + nlp = Language() + spancat = nlp.add_pipe( + "spancat", + config={ + "spans_key": SPAN_KEY, + "threshold": threshold, + "max_positive": 1, + }, + ) + doc = nlp.make_doc("Greater London") + ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2]) + indices = ngram_suggester([doc])[0].dataXd + assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) + labels = ["Thing", "City", "Person", "GreatCity"] + scores = numpy.asarray( + [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" + ) + spangroup = spancat._make_span_group_singlelabel( + doc, indices, scores, labels, allow_overlap + ) + assert len(spangroup) == nr_results + if threshold > 0.4: + if allow_overlap: + assert spangroup[0].text == "London" + assert spangroup[0].label_ == "City" + assert spangroup[1].text == "Greater London" + assert spangroup[1].label_ == "GreatCity" + + else: + assert spangroup[0].text == "Greater London" + assert spangroup[0].label_ == "GreatCity" + else: + if allow_overlap: + assert spangroup[0].text == "Greater" + assert spangroup[0].label_ == "City" + assert spangroup[1].text == "London" + assert spangroup[1].label_ == "City" + assert spangroup[2].text == "Greater London" + assert spangroup[2].label_ == "GreatCity" + else: + assert spangroup[0].text == "Greater London" + + def test_ngram_suggester(en_tokenizer): # test different n-gram lengths for size in [1, 2, 3]: