This commit is contained in:
kadarakos 2023-02-10 14:07:39 +00:00
parent a07aafc28e
commit afc3a5a4af

View File

@ -161,6 +161,7 @@ def test_make_spangroup_multilabel(max_positive, nr_results):
assert spangroup[-1].label_ == "GreatCity"
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
@pytest.mark.parametrize(
"threshold,allow_overlap,nr_results",
[(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
@ -243,14 +244,15 @@ def test_make_spangroup_negative_label():
indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1, 0.1], [0.1, 0.6, 0.2, 0.4, 0.9], [0.8, 0.7, 0.3, 0.9, 0.1]], dtype="f"
)
spangroup_multi = spancat_multi._make_span_group_multilabel(
doc, indices, scores
)
spangroup_single = spancat_single._make_span_group_singlelabel(
doc, indices, scores
[
[0.2, 0.4, 0.3, 0.1, 0.1],
[0.1, 0.6, 0.2, 0.4, 0.9],
[0.8, 0.7, 0.3, 0.9, 0.1],
],
dtype="f",
)
spangroup_multi = spancat_multi._make_span_group_multilabel(doc, indices, scores)
spangroup_single = spancat_single._make_span_group_singlelabel(doc, indices, scores)
assert len(spangroup_single) == 2
assert spangroup_single[0].text == "Greater"
assert spangroup_single[0].label_ == "City"