From b0e6d698f67cc6da6f0ca793bf7010edd1833eea Mon Sep 17 00:00:00 2001 From: kadarakos Date: Mon, 27 Mar 2023 09:04:02 +0000 Subject: [PATCH] debug argmax sort and add span scores --- spacy/pipeline/spancat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 983e1fba9..ff68a3703 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -726,6 +726,7 @@ class SpanCategorizer(TrainablePipe): if not allow_overlap: # Get the probabilities sort_idx = (argmax_scores.squeeze() * -1).argsort() + argmax_scores = argmax_scores[sort_idx] predicted = predicted[sort_idx] indices = indices[sort_idx] keeps = keeps[sort_idx] @@ -748,4 +749,5 @@ class SpanCategorizer(TrainablePipe): attrs_scores.append(argmax_scores[i]) spans.append(Span(doc, start, end, label=self.labels[label])) + spans.attrs["scores"] = numpy.array(attrs_scores) return spans