mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-24 19:11:58 +03:00
Handle zero suggestions to make tests pass
I'm not sure if this is the most elegant solution. But what should happen is that the _make_span_group function MUST return an empty SpanGroup if there are no suggestions. The error happens when the 'scores' variable is empty. We cannot get the 'predicted' and other downstream vars.
This commit is contained in:
parent
f476317387
commit
a3fad0b983
|
@ -266,23 +266,26 @@ class Exclusive_SpanCategorizer(SpanCategorizer):
|
|||
) -> SpanGroup:
|
||||
scores = self.model.ops.to_numpy(scores)
|
||||
indices = self.model.ops.to_numpy(indices)
|
||||
predicted = scores.argmax(axis=1)
|
||||
if scores.size != 0:
|
||||
predicted = scores.argmax(axis=1)
|
||||
|
||||
# Remove samples where the negative label is the argmax
|
||||
positive = numpy.where(predicted != self._negative_label)[0]
|
||||
predicted = predicted[positive]
|
||||
indices = indices[positive]
|
||||
# Remove samples where the negative label is the argmax
|
||||
positive = numpy.where(predicted != self._negative_label)[0]
|
||||
predicted = predicted[positive]
|
||||
indices = indices[positive]
|
||||
|
||||
# Sort spans according to argmax probability
|
||||
if not allow_overlap:
|
||||
# Get the probabilities
|
||||
argmax_probs = numpy.take_along_axis(
|
||||
scores[positive], numpy.expand_dims(predicted, 1), axis=1
|
||||
)
|
||||
argmax_probs = argmax_probs.squeeze()
|
||||
sort_idx = (argmax_probs * -1).argsort()
|
||||
predicted = predicted[sort_idx]
|
||||
indices = indices[sort_idx]
|
||||
# Sort spans according to argmax probability
|
||||
if not allow_overlap and predicted.size != 0:
|
||||
# Get the probabilities
|
||||
argmax_probs = numpy.take_along_axis(
|
||||
scores[positive], numpy.expand_dims(predicted, 1), axis=1
|
||||
)
|
||||
argmax_probs = argmax_probs.squeeze()
|
||||
sort_idx = (argmax_probs * -1).argsort()
|
||||
predicted = predicted[sort_idx]
|
||||
indices = indices[sort_idx]
|
||||
else:
|
||||
predicted = []
|
||||
|
||||
seen = Ranges()
|
||||
spans = SpanGroup(doc, name=self.key)
|
||||
|
|
Loading…
Reference in New Issue
Block a user