Fix spancat-singlelabel score (#12469)

* debug argmax sort and add span scores

* add missing tests for spanscores
This commit is contained in:
kadarakos 2023-03-29 08:38:11 +02:00 committed by GitHub
parent dba4e7bece
commit 372a90885e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -726,6 +726,7 @@ class SpanCategorizer(TrainablePipe):
if not allow_overlap: if not allow_overlap:
# Get the probabilities # Get the probabilities
sort_idx = (argmax_scores.squeeze() * -1).argsort() sort_idx = (argmax_scores.squeeze() * -1).argsort()
argmax_scores = argmax_scores[sort_idx]
predicted = predicted[sort_idx] predicted = predicted[sort_idx]
indices = indices[sort_idx] indices = indices[sort_idx]
keeps = keeps[sort_idx] keeps = keeps[sort_idx]
@ -748,4 +749,5 @@ class SpanCategorizer(TrainablePipe):
attrs_scores.append(argmax_scores[i]) attrs_scores.append(argmax_scores[i])
spans.append(Span(doc, start, end, label=self.labels[label])) spans.append(Span(doc, start, end, label=self.labels[label]))
spans.attrs["scores"] = numpy.array(attrs_scores)
return spans return spans

View File

@ -190,17 +190,19 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
spangroup = spancat._make_span_group_singlelabel( spangroup = spancat._make_span_group_singlelabel(
doc, indices, scores, allow_overlap doc, indices, scores, allow_overlap
) )
assert len(spangroup) == nr_results
if threshold > 0.4: if threshold > 0.4:
if allow_overlap: if allow_overlap:
assert spangroup[0].text == "London" assert spangroup[0].text == "London"
assert spangroup[0].label_ == "City" assert spangroup[0].label_ == "City"
assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5)
assert spangroup[1].text == "Greater London" assert spangroup[1].text == "Greater London"
assert spangroup[1].label_ == "GreatCity" assert spangroup[1].label_ == "GreatCity"
assert spangroup.attrs["scores"][1] == 0.9
assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5)
else: else:
assert spangroup[0].text == "Greater London" assert spangroup[0].text == "Greater London"
assert spangroup[0].label_ == "GreatCity" assert spangroup[0].label_ == "GreatCity"
assert spangroup.attrs["scores"][0] == 0.9
else: else:
if allow_overlap: if allow_overlap:
assert spangroup[0].text == "Greater" assert spangroup[0].text == "Greater"
@ -256,22 +258,32 @@ def test_make_spangroup_negative_label():
assert len(spangroup_single) == 2 assert len(spangroup_single) == 2
assert spangroup_single[0].text == "Greater" assert spangroup_single[0].text == "Greater"
assert spangroup_single[0].label_ == "City" assert spangroup_single[0].label_ == "City"
assert_almost_equal(0.4, spangroup_single.attrs["scores"][0], 5)
assert spangroup_single[1].text == "Greater London" assert spangroup_single[1].text == "Greater London"
assert spangroup_single[1].label_ == "GreatCity" assert spangroup_single[1].label_ == "GreatCity"
assert spangroup_single.attrs["scores"][1] == 0.9
assert_almost_equal(0.9, spangroup_single.attrs["scores"][1], 5)
assert len(spangroup_multi) == 6 assert len(spangroup_multi) == 6
assert spangroup_multi[0].text == "Greater" assert spangroup_multi[0].text == "Greater"
assert spangroup_multi[0].label_ == "City" assert spangroup_multi[0].label_ == "City"
assert_almost_equal(0.4, spangroup_multi.attrs["scores"][0], 5)
assert spangroup_multi[1].text == "Greater" assert spangroup_multi[1].text == "Greater"
assert spangroup_multi[1].label_ == "Person" assert spangroup_multi[1].label_ == "Person"
assert_almost_equal(0.3, spangroup_multi.attrs["scores"][1], 5)
assert spangroup_multi[2].text == "London" assert spangroup_multi[2].text == "London"
assert spangroup_multi[2].label_ == "City" assert spangroup_multi[2].label_ == "City"
assert_almost_equal(0.6, spangroup_multi.attrs["scores"][2], 5)
assert spangroup_multi[3].text == "London" assert spangroup_multi[3].text == "London"
assert spangroup_multi[3].label_ == "GreatCity" assert spangroup_multi[3].label_ == "GreatCity"
assert_almost_equal(0.4, spangroup_multi.attrs["scores"][3], 5)
assert spangroup_multi[4].text == "Greater London" assert spangroup_multi[4].text == "Greater London"
assert spangroup_multi[4].label_ == "Thing" assert spangroup_multi[4].label_ == "Thing"
assert spangroup_multi[4].text == "Greater London"
assert_almost_equal(0.8, spangroup_multi.attrs["scores"][4], 5)
assert spangroup_multi[5].text == "Greater London" assert spangroup_multi[5].text == "Greater London"
assert spangroup_multi[5].label_ == "GreatCity" assert spangroup_multi[5].label_ == "GreatCity"
assert_almost_equal(0.9, spangroup_multi.attrs["scores"][5], 5)
def test_ngram_suggester(en_tokenizer): def test_ngram_suggester(en_tokenizer):