diff --git a/spacy/pipeline/spancat_exclusive.py b/spacy/pipeline/spancat_exclusive.py index ed5f3272e..fa6d8ae60 100644 --- a/spacy/pipeline/spancat_exclusive.py +++ b/spacy/pipeline/spancat_exclusive.py @@ -266,23 +266,26 @@ class Exclusive_SpanCategorizer(SpanCategorizer): ) -> SpanGroup: scores = self.model.ops.to_numpy(scores) indices = self.model.ops.to_numpy(indices) - predicted = scores.argmax(axis=1) + if scores.size != 0: + predicted = scores.argmax(axis=1) - # Remove samples where the negative label is the argmax - positive = numpy.where(predicted != self._negative_label)[0] - predicted = predicted[positive] - indices = indices[positive] + # Remove samples where the negative label is the argmax + positive = numpy.where(predicted != self._negative_label)[0] + predicted = predicted[positive] + indices = indices[positive] - # Sort spans according to argmax probability - if not allow_overlap: - # Get the probabilities - argmax_probs = numpy.take_along_axis( - scores[positive], numpy.expand_dims(predicted, 1), axis=1 - ) - argmax_probs = argmax_probs.squeeze() - sort_idx = (argmax_probs * -1).argsort() - predicted = predicted[sort_idx] - indices = indices[sort_idx] + # Sort spans according to argmax probability + if not allow_overlap and predicted.size != 0: + # Get the probabilities + argmax_probs = numpy.take_along_axis( + scores[positive], numpy.expand_dims(predicted, 1), axis=1 + ) + argmax_probs = argmax_probs.squeeze() + sort_idx = (argmax_probs * -1).argsort() + predicted = predicted[sort_idx] + indices = indices[sort_idx] + else: + predicted = [] seen = Ranges() spans = SpanGroup(doc, name=self.key)