diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index 2484b5ff5..2694c01f4 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -161,6 +161,7 @@ def test_make_spangroup_multilabel(max_positive, nr_results): assert spangroup[-1].label_ == "GreatCity" assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5) + @pytest.mark.parametrize( "threshold,allow_overlap,nr_results", [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)], @@ -243,14 +244,15 @@ def test_make_spangroup_negative_label(): indices = ngram_suggester([doc])[0].dataXd assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) scores = numpy.asarray( - [[0.2, 0.4, 0.3, 0.1, 0.1], [0.1, 0.6, 0.2, 0.4, 0.9], [0.8, 0.7, 0.3, 0.9, 0.1]], dtype="f" - ) - spangroup_multi = spancat_multi._make_span_group_multilabel( - doc, indices, scores - ) - spangroup_single = spancat_single._make_span_group_singlelabel( - doc, indices, scores + [ + [0.2, 0.4, 0.3, 0.1, 0.1], + [0.1, 0.6, 0.2, 0.4, 0.9], + [0.8, 0.7, 0.3, 0.9, 0.1], + ], + dtype="f", ) + spangroup_multi = spancat_multi._make_span_group_multilabel(doc, indices, scores) + spangroup_single = spancat_single._make_span_group_singlelabel(doc, indices, scores) assert len(spangroup_single) == 2 assert spangroup_single[0].text == "Greater" assert spangroup_single[0].label_ == "City"