diff --git a/spacy/pipeline/spancat_exclusive.py b/spacy/pipeline/spancat_exclusive.py index fa6d8ae60..5605a40f5 100644 --- a/spacy/pipeline/spancat_exclusive.py +++ b/spacy/pipeline/spancat_exclusive.py @@ -266,26 +266,27 @@ class Exclusive_SpanCategorizer(SpanCategorizer): ) -> SpanGroup: scores = self.model.ops.to_numpy(scores) indices = self.model.ops.to_numpy(indices) - if scores.size != 0: - predicted = scores.argmax(axis=1) - # Remove samples where the negative label is the argmax - positive = numpy.where(predicted != self._negative_label)[0] - predicted = predicted[positive] - indices = indices[positive] + # Handle cases when there are zero suggestions + if scores.size == 0: + return SpanGroup(doc, name=self.key) - # Sort spans according to argmax probability - if not allow_overlap and predicted.size != 0: - # Get the probabilities - argmax_probs = numpy.take_along_axis( - scores[positive], numpy.expand_dims(predicted, 1), axis=1 - ) - argmax_probs = argmax_probs.squeeze() - sort_idx = (argmax_probs * -1).argsort() - predicted = predicted[sort_idx] - indices = indices[sort_idx] - else: - predicted = [] + predicted = scores.argmax(axis=1) + # Remove samples where the negative label is the argmax + positive = numpy.where(predicted != self._negative_label)[0] + predicted = predicted[positive] + indices = indices[positive] + + # Sort spans according to argmax probability + if not allow_overlap and predicted.size != 0: + # Get the probabilities + argmax_probs = numpy.take_along_axis( + scores[positive], numpy.expand_dims(predicted, 1), axis=1 + ) + argmax_probs = argmax_probs.squeeze() + sort_idx = (argmax_probs * -1).argsort() + predicted = predicted[sort_idx] + indices = indices[sort_idx] seen = Ranges() spans = SpanGroup(doc, name=self.key)