diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index c8c87f990..2484b5ff5 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -126,10 +126,12 @@ def test_make_spangroup_multilabel(max_positive, nr_results): indices = ngram_suggester([doc])[0].dataXd assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) labels = ["Thing", "City", "Person", "GreatCity"] + for label in labels: + spancat.add_label(label) scores = numpy.asarray( [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" ) - spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels) + spangroup = spancat._make_span_group_multilabel(doc, indices, scores) assert len(spangroup) == nr_results # first span is always the second token "London" @@ -159,7 +161,6 @@ def test_make_spangroup_multilabel(max_positive, nr_results): assert spangroup[-1].label_ == "GreatCity" assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5) - @pytest.mark.parametrize( "threshold,allow_overlap,nr_results", [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)], @@ -180,11 +181,13 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results): indices = ngram_suggester([doc])[0].dataXd assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) labels = ["Thing", "City", "Person", "GreatCity"] + for label in labels: + spancat.add_label(label) scores = numpy.asarray( [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" ) spangroup = spancat._make_span_group_singlelabel( - doc, indices, scores, labels, allow_overlap + doc, indices, scores, allow_overlap ) assert len(spangroup) == nr_results if threshold > 0.4: @@ -209,6 +212,66 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results): assert spangroup[0].text == "Greater London" +def test_make_spangroup_negative_label(): + fix_random_seed(0) + nlp_single = Language() + nlp_multi = Language() + spancat_single = nlp_single.add_pipe( + "spancat", + config={ + "spans_key": SPAN_KEY, + "threshold": 0.1, + "max_positive": 1, + }, + ) + spancat_multi = nlp_multi.add_pipe( + "spancat", + config={ + "spans_key": SPAN_KEY, + "threshold": 0.1, + "max_positive": 2, + }, + ) + spancat_single.add_negative_label = True + spancat_multi.add_negative_label = True + doc = nlp_single.make_doc("Greater London") + labels = ["Thing", "City", "Person", "GreatCity"] + for label in labels: + spancat_multi.add_label(label) + spancat_single.add_label(label) + ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2]) + indices = ngram_suggester([doc])[0].dataXd + assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) + scores = numpy.asarray( + [[0.2, 0.4, 0.3, 0.1, 0.1], [0.1, 0.6, 0.2, 0.4, 0.9], [0.8, 0.7, 0.3, 0.9, 0.1]], dtype="f" + ) + spangroup_multi = spancat_multi._make_span_group_multilabel( + doc, indices, scores + ) + spangroup_single = spancat_single._make_span_group_singlelabel( + doc, indices, scores + ) + assert len(spangroup_single) == 2 + assert spangroup_single[0].text == "Greater" + assert spangroup_single[0].label_ == "City" + assert spangroup_single[1].text == "Greater London" + assert spangroup_single[1].label_ == "GreatCity" + + assert len(spangroup_multi) == 6 + assert spangroup_multi[0].text == "Greater" + assert spangroup_multi[0].label_ == "City" + assert spangroup_multi[1].text == "Greater" + assert spangroup_multi[1].label_ == "Person" + assert spangroup_multi[2].text == "London" + assert spangroup_multi[2].label_ == "City" + assert spangroup_multi[3].text == "London" + assert spangroup_multi[3].label_ == "GreatCity" + assert spangroup_multi[4].text == "Greater London" + assert spangroup_multi[4].label_ == "Thing" + assert spangroup_multi[5].text == "Greater London" + assert spangroup_multi[5].label_ == "GreatCity" + + def test_ngram_suggester(en_tokenizer): # test different n-gram lengths for size in [1, 2, 3]: