tests for make_span_group with negative labels

This commit is contained in:
kadarakos 2023-02-10 14:06:07 +00:00
parent b98cba2bef
commit a281a7c9a1

View File

@ -126,10 +126,12 @@ def test_make_spangroup_multilabel(max_positive, nr_results):
indices = ngram_suggester([doc])[0].dataXd indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
labels = ["Thing", "City", "Person", "GreatCity"] labels = ["Thing", "City", "Person", "GreatCity"]
for label in labels:
spancat.add_label(label)
scores = numpy.asarray( scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
) )
spangroup = spancat._make_span_group_multilabel(doc, indices, scores, labels) spangroup = spancat._make_span_group_multilabel(doc, indices, scores)
assert len(spangroup) == nr_results assert len(spangroup) == nr_results
# first span is always the second token "London" # first span is always the second token "London"
@ -159,7 +161,6 @@ def test_make_spangroup_multilabel(max_positive, nr_results):
assert spangroup[-1].label_ == "GreatCity" assert spangroup[-1].label_ == "GreatCity"
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5) assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"threshold,allow_overlap,nr_results", "threshold,allow_overlap,nr_results",
[(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)], [(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
@ -180,11 +181,13 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
indices = ngram_suggester([doc])[0].dataXd indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]])) assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
labels = ["Thing", "City", "Person", "GreatCity"] labels = ["Thing", "City", "Person", "GreatCity"]
for label in labels:
spancat.add_label(label)
scores = numpy.asarray( scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
) )
spangroup = spancat._make_span_group_singlelabel( spangroup = spancat._make_span_group_singlelabel(
doc, indices, scores, labels, allow_overlap doc, indices, scores, allow_overlap
) )
assert len(spangroup) == nr_results assert len(spangroup) == nr_results
if threshold > 0.4: if threshold > 0.4:
@ -209,6 +212,66 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
assert spangroup[0].text == "Greater London" assert spangroup[0].text == "Greater London"
def test_make_spangroup_negative_label():
fix_random_seed(0)
nlp_single = Language()
nlp_multi = Language()
spancat_single = nlp_single.add_pipe(
"spancat",
config={
"spans_key": SPAN_KEY,
"threshold": 0.1,
"max_positive": 1,
},
)
spancat_multi = nlp_multi.add_pipe(
"spancat",
config={
"spans_key": SPAN_KEY,
"threshold": 0.1,
"max_positive": 2,
},
)
spancat_single.add_negative_label = True
spancat_multi.add_negative_label = True
doc = nlp_single.make_doc("Greater London")
labels = ["Thing", "City", "Person", "GreatCity"]
for label in labels:
spancat_multi.add_label(label)
spancat_single.add_label(label)
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1, 0.1], [0.1, 0.6, 0.2, 0.4, 0.9], [0.8, 0.7, 0.3, 0.9, 0.1]], dtype="f"
)
spangroup_multi = spancat_multi._make_span_group_multilabel(
doc, indices, scores
)
spangroup_single = spancat_single._make_span_group_singlelabel(
doc, indices, scores
)
assert len(spangroup_single) == 2
assert spangroup_single[0].text == "Greater"
assert spangroup_single[0].label_ == "City"
assert spangroup_single[1].text == "Greater London"
assert spangroup_single[1].label_ == "GreatCity"
assert len(spangroup_multi) == 6
assert spangroup_multi[0].text == "Greater"
assert spangroup_multi[0].label_ == "City"
assert spangroup_multi[1].text == "Greater"
assert spangroup_multi[1].label_ == "Person"
assert spangroup_multi[2].text == "London"
assert spangroup_multi[2].label_ == "City"
assert spangroup_multi[3].text == "London"
assert spangroup_multi[3].label_ == "GreatCity"
assert spangroup_multi[4].text == "Greater London"
assert spangroup_multi[4].label_ == "Thing"
assert spangroup_multi[5].text == "Greater London"
assert spangroup_multi[5].label_ == "GreatCity"
def test_ngram_suggester(en_tokenizer): def test_ngram_suggester(en_tokenizer):
# test different n-gram lengths # test different n-gram lengths
for size in [1, 2, 3]: for size in [1, 2, 3]: