SpanCategorizer: keep spans.attrs["scores"] aligned with the emitted spans.

Both span-group builders now collect one score per Span inside the loop that
appends the Span, and assign the collected scores to spans.attrs["scores"]
after the loop. The single-label builder filters with a boolean `keeps` mask
instead of re-indexing `predicted` and `indices`.

NOTE(review): the statements under `if not allow_overlap:` lie between hunks
and are not shown here; confirm they reorder `keeps` and `argmax_scores`
together with `predicted` and `indices`.

diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index d1dd2b9d5..b5872daf3 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -690,13 +690,11 @@ class SpanCategorizer(TrainablePipe):
             span_filter = ranked[:, max_positive:]
             for i, row in enumerate(span_filter):
                 keeps[i, row] = False
-        # TODO I think this is now incorrect
-        spans.attrs["scores"] = scores[keeps].flatten()
 
+        attrs_scores = []
         for i in range(indices.shape[0]):
             start = indices[i, 0]
             end = indices[i, 1]
-
             for j, keep in enumerate(keeps[i]):
                 if keep:
                     # If the predicted label is the negative label skip it.
@@ -704,7 +702,8 @@ class SpanCategorizer(TrainablePipe):
                         continue
                     else:
                         spans.append(Span(doc, start, end, label=labels[j]))
-
+                        attrs_scores.append(scores[i, j])
+        spans.attrs["scores"] = numpy.array(attrs_scores)
         return spans
 
     def _make_span_group_singlelabel(
@@ -721,21 +720,19 @@ class SpanCategorizer(TrainablePipe):
             return spans
         scores = self.model.ops.to_numpy(scores)
         indices = self.model.ops.to_numpy(indices)
-        threshold = self.cfg["threshold"]
         predicted = scores.argmax(axis=1)
+        # Squeeze only axis 1 so a single-span input stays 1-D like `keeps`.
         argmax_scores = numpy.take_along_axis(
             scores, numpy.expand_dims(predicted, 1), axis=1
-        ).squeeze()
+        ).squeeze(axis=1)
+        keeps = numpy.ones(predicted.shape, dtype=bool)
         # Remove samples where the negative label is the argmax.
         if self.add_negative_label:
-            positive = numpy.where(predicted != self._negative_label)[0]
-            predicted = predicted[positive]
-            indices = indices[positive]
+            keeps = numpy.logical_and(keeps, predicted != self._negative_label)
         # Filter samples according to threshold.
+        threshold = self.cfg["threshold"]
         if threshold is not None:
-            keeps = numpy.where(argmax_scores >= threshold)
-            predicted = predicted[keeps]
-            indices = indices[keeps]
+            keeps = numpy.logical_and(keeps, argmax_scores >= threshold)
         # Sort spans according to argmax probability
         if not allow_overlap:
             # Get the probabilities
@@ -746,7 +743,9 @@ class SpanCategorizer(TrainablePipe):
-        # TODO assigns spans.attrs["scores"]
         seen = Intervals()
         spans = SpanGroup(doc, name=self.key)
-        for i in range(len(predicted)):
+        attrs_scores = []
+        for i in range(indices.shape[0]):
+            if not keeps[i]:
+                continue
             label = predicted[i]
             start = indices[i, 0]
             end = indices[i, 1]
@@ -756,7 +755,8 @@ class SpanCategorizer(TrainablePipe):
                     continue
                 else:
                     seen.add(start, end)
-
+            attrs_scores.append(argmax_scores[i])
             spans.append(Span(doc, start, end, label=labels[label]))
 
+        spans.attrs["scores"] = numpy.array(attrs_scores)
         return spans