Handle zero suggestions to make tests pass

I'm not sure if this is the most elegant solution. But what should
happen is that the _make_span_group function MUST return an empty
SpanGroup if there are no suggestions.

The error happens when the 'scores' variable is empty. We cannot
get the 'predicted' and other downstream vars.
This commit is contained in:
Lj Miranda 2022-12-21 10:36:01 +08:00
parent f476317387
commit a3fad0b983

View File

@ -266,23 +266,26 @@ class Exclusive_SpanCategorizer(SpanCategorizer):
) -> SpanGroup:
scores = self.model.ops.to_numpy(scores)
indices = self.model.ops.to_numpy(indices)
predicted = scores.argmax(axis=1)
if scores.size != 0:
predicted = scores.argmax(axis=1)
# Remove samples where the negative label is the argmax
positive = numpy.where(predicted != self._negative_label)[0]
predicted = predicted[positive]
indices = indices[positive]
# Remove samples where the negative label is the argmax
positive = numpy.where(predicted != self._negative_label)[0]
predicted = predicted[positive]
indices = indices[positive]
# Sort spans according to argmax probability
if not allow_overlap:
# Get the probabilities
argmax_probs = numpy.take_along_axis(
scores[positive], numpy.expand_dims(predicted, 1), axis=1
)
argmax_probs = argmax_probs.squeeze()
sort_idx = (argmax_probs * -1).argsort()
predicted = predicted[sort_idx]
indices = indices[sort_idx]
# Sort spans according to argmax probability
if not allow_overlap and predicted.size != 0:
# Get the probabilities
argmax_probs = numpy.take_along_axis(
scores[positive], numpy.expand_dims(predicted, 1), axis=1
)
argmax_probs = argmax_probs.squeeze()
sort_idx = (argmax_probs * -1).argsort()
predicted = predicted[sort_idx]
indices = indices[sort_idx]
else:
predicted = []
seen = Ranges()
spans = SpanGroup(doc, name=self.key)