debug argmax sort and add span scores

This commit is contained in:
kadarakos 2023-03-27 09:04:02 +00:00
parent 28de85737f
commit b0e6d698f6

View File

@ -726,6 +726,7 @@ class SpanCategorizer(TrainablePipe):
if not allow_overlap:
# Get the probabilities
sort_idx = (argmax_scores.squeeze() * -1).argsort()
argmax_scores = argmax_scores[sort_idx]
predicted = predicted[sort_idx]
indices = indices[sort_idx]
keeps = keeps[sort_idx]
@ -748,4 +749,5 @@ class SpanCategorizer(TrainablePipe):
attrs_scores.append(argmax_scores[i])
spans.append(Span(doc, start, end, label=self.labels[label]))
spans.attrs["scores"] = numpy.array(attrs_scores)
return spans