diff --git a/spacy/pipeline/spancat_exclusive.py b/spacy/pipeline/spancat_exclusive.py index 3d169b199..03bcc7b83 100644 --- a/spacy/pipeline/spancat_exclusive.py +++ b/spacy/pipeline/spancat_exclusive.py @@ -272,7 +272,7 @@ class SpanCategorizerExclusive(TrainablePipe): """Modify a batch of Doc objects, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. - scores: The scores to set, produced by SpanCategorizer.predict. + scores: The scores to set, produced by SpanCategorizerExclusive.predict. DOCS: https://spacy.io/api/spancategorizerexclusive#set_annotations """ @@ -446,10 +446,12 @@ class SpanCategorizerExclusive(TrainablePipe): scores = self.model.ops.to_numpy(scores) indices = self.model.ops.to_numpy(indices) predicted = scores.argmax(axis=1) + # Remove samples where the negative label is the argmax positive = numpy.where(predicted != self._negative_label) predicted = predicted[positive[0]] indices = indices[positive[0]] + # Sort spans according to argmax probability if not allow_overlap: argmax_probs = numpy.take_along_axis( @@ -459,16 +461,20 @@ class SpanCategorizerExclusive(TrainablePipe): sort_idx = (argmax_probs * -1).argsort() predicted = predicted[sort_idx] indices = indices[sort_idx] + seen = Ranges() spans = SpanGroup(doc, name=self.key) for i in range(len(predicted)): label = predicted[i] start = indices[i, 0] end = indices[i, 1] + if not allow_overlap: if (start, end) in seen: continue else: seen.add(start, end) + spans.append(Span(doc, start, end, label=labels[label])) + return spans