mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 21:57:15 +03:00
Fix spancat-singlelabel score (#12469)
* debug argmax sort and add span scores * add missing tests for spanscores
This commit is contained in:
parent
888332dfb2
commit
26da226a39
|
@ -726,6 +726,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
if not allow_overlap:
|
if not allow_overlap:
|
||||||
# Get the probabilities
|
# Get the probabilities
|
||||||
sort_idx = (argmax_scores.squeeze() * -1).argsort()
|
sort_idx = (argmax_scores.squeeze() * -1).argsort()
|
||||||
|
argmax_scores = argmax_scores[sort_idx]
|
||||||
predicted = predicted[sort_idx]
|
predicted = predicted[sort_idx]
|
||||||
indices = indices[sort_idx]
|
indices = indices[sort_idx]
|
||||||
keeps = keeps[sort_idx]
|
keeps = keeps[sort_idx]
|
||||||
|
@ -748,4 +749,5 @@ class SpanCategorizer(TrainablePipe):
|
||||||
attrs_scores.append(argmax_scores[i])
|
attrs_scores.append(argmax_scores[i])
|
||||||
spans.append(Span(doc, start, end, label=self.labels[label]))
|
spans.append(Span(doc, start, end, label=self.labels[label]))
|
||||||
|
|
||||||
|
spans.attrs["scores"] = numpy.array(attrs_scores)
|
||||||
return spans
|
return spans
|
||||||
|
|
|
@ -190,17 +190,19 @@ def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
|
||||||
spangroup = spancat._make_span_group_singlelabel(
|
spangroup = spancat._make_span_group_singlelabel(
|
||||||
doc, indices, scores, allow_overlap
|
doc, indices, scores, allow_overlap
|
||||||
)
|
)
|
||||||
assert len(spangroup) == nr_results
|
|
||||||
if threshold > 0.4:
|
if threshold > 0.4:
|
||||||
if allow_overlap:
|
if allow_overlap:
|
||||||
assert spangroup[0].text == "London"
|
assert spangroup[0].text == "London"
|
||||||
assert spangroup[0].label_ == "City"
|
assert spangroup[0].label_ == "City"
|
||||||
|
assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5)
|
||||||
assert spangroup[1].text == "Greater London"
|
assert spangroup[1].text == "Greater London"
|
||||||
assert spangroup[1].label_ == "GreatCity"
|
assert spangroup[1].label_ == "GreatCity"
|
||||||
|
assert spangroup.attrs["scores"][1] == 0.9
|
||||||
|
assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5)
|
||||||
else:
|
else:
|
||||||
assert spangroup[0].text == "Greater London"
|
assert spangroup[0].text == "Greater London"
|
||||||
assert spangroup[0].label_ == "GreatCity"
|
assert spangroup[0].label_ == "GreatCity"
|
||||||
|
assert spangroup.attrs["scores"][0] == 0.9
|
||||||
else:
|
else:
|
||||||
if allow_overlap:
|
if allow_overlap:
|
||||||
assert spangroup[0].text == "Greater"
|
assert spangroup[0].text == "Greater"
|
||||||
|
@ -256,22 +258,32 @@ def test_make_spangroup_negative_label():
|
||||||
assert len(spangroup_single) == 2
|
assert len(spangroup_single) == 2
|
||||||
assert spangroup_single[0].text == "Greater"
|
assert spangroup_single[0].text == "Greater"
|
||||||
assert spangroup_single[0].label_ == "City"
|
assert spangroup_single[0].label_ == "City"
|
||||||
|
assert_almost_equal(0.4, spangroup_single.attrs["scores"][0], 5)
|
||||||
assert spangroup_single[1].text == "Greater London"
|
assert spangroup_single[1].text == "Greater London"
|
||||||
assert spangroup_single[1].label_ == "GreatCity"
|
assert spangroup_single[1].label_ == "GreatCity"
|
||||||
|
assert spangroup_single.attrs["scores"][1] == 0.9
|
||||||
|
assert_almost_equal(0.9, spangroup_single.attrs["scores"][1], 5)
|
||||||
|
|
||||||
assert len(spangroup_multi) == 6
|
assert len(spangroup_multi) == 6
|
||||||
assert spangroup_multi[0].text == "Greater"
|
assert spangroup_multi[0].text == "Greater"
|
||||||
assert spangroup_multi[0].label_ == "City"
|
assert spangroup_multi[0].label_ == "City"
|
||||||
|
assert_almost_equal(0.4, spangroup_multi.attrs["scores"][0], 5)
|
||||||
assert spangroup_multi[1].text == "Greater"
|
assert spangroup_multi[1].text == "Greater"
|
||||||
assert spangroup_multi[1].label_ == "Person"
|
assert spangroup_multi[1].label_ == "Person"
|
||||||
|
assert_almost_equal(0.3, spangroup_multi.attrs["scores"][1], 5)
|
||||||
assert spangroup_multi[2].text == "London"
|
assert spangroup_multi[2].text == "London"
|
||||||
assert spangroup_multi[2].label_ == "City"
|
assert spangroup_multi[2].label_ == "City"
|
||||||
|
assert_almost_equal(0.6, spangroup_multi.attrs["scores"][2], 5)
|
||||||
assert spangroup_multi[3].text == "London"
|
assert spangroup_multi[3].text == "London"
|
||||||
assert spangroup_multi[3].label_ == "GreatCity"
|
assert spangroup_multi[3].label_ == "GreatCity"
|
||||||
|
assert_almost_equal(0.4, spangroup_multi.attrs["scores"][3], 5)
|
||||||
assert spangroup_multi[4].text == "Greater London"
|
assert spangroup_multi[4].text == "Greater London"
|
||||||
assert spangroup_multi[4].label_ == "Thing"
|
assert spangroup_multi[4].label_ == "Thing"
|
||||||
|
assert spangroup_multi[4].text == "Greater London"
|
||||||
|
assert_almost_equal(0.8, spangroup_multi.attrs["scores"][4], 5)
|
||||||
assert spangroup_multi[5].text == "Greater London"
|
assert spangroup_multi[5].text == "Greater London"
|
||||||
assert spangroup_multi[5].label_ == "GreatCity"
|
assert spangroup_multi[5].label_ == "GreatCity"
|
||||||
|
assert_almost_equal(0.9, spangroup_multi.attrs["scores"][5], 5)
|
||||||
|
|
||||||
|
|
||||||
def test_ngram_suggester(en_tokenizer):
|
def test_ngram_suggester(en_tokenizer):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user