diff --git a/spacy/pipeline/spancat_exclusive.py b/spacy/pipeline/spancat_exclusive.py index 9a52fc1d4..b5af27265 100644 --- a/spacy/pipeline/spancat_exclusive.py +++ b/spacy/pipeline/spancat_exclusive.py @@ -215,17 +215,18 @@ class SpanCategorizerExclusive(TrainablePipe): return list(self.labels) @property - def _negative_label(self): - """ - Index of the negative label. - """ + def label_map(self) -> Dict[str, int]: + """RETURNS (Dict[str, int]): The label map.""" + return {label: i for i, label in enumerate(self.labels)} + + @property + def _negative_label(self) -> int: + """RETURNS (int): Index of the negative label.""" return len(self.label_data) @property - def _n_labels(self): - """ - Number of labels including the negative label. - """ + def _n_labels(self) -> int: + """RETURNS (int): Number of labels including the negative label.""" return len(self.label_data) + 1 def predict(self, docs: Iterable[Doc]): @@ -339,7 +340,6 @@ class SpanCategorizerExclusive(TrainablePipe): spans = Ragged( self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths) ) - label_map = {label: i for i, label in enumerate(self.labels)} target = numpy.zeros(scores.shape, dtype=scores.dtype) # Set negative class as target initially for all samples. negative_spans = numpy.ones((scores.shape[0])) @@ -358,7 +358,7 @@ class SpanCategorizerExclusive(TrainablePipe): key = (gold_span.start, gold_span.end) if key in spans_index: row = spans_index[key] - k = label_map[gold_span.label_] + k = self.label_map[gold_span.label_] target[row, k] = 1.0 # delete negative label target. negative_spans[row] = 0.0 @@ -438,14 +438,15 @@ class SpanCategorizerExclusive(TrainablePipe): predicted = scores.argmax(axis=1) # Remove samples where the negative label is the argmax - positive = numpy.where(predicted != self._negative_label) - predicted = predicted[positive[0]] - indices = indices[positive[0]] + positive = numpy.where(predicted != self._negative_label)[0] + predicted = predicted[positive] + indices = indices[positive] # Sort spans according to argmax probability if not allow_overlap: + # Get the probabilities argmax_probs = numpy.take_along_axis( - scores[positive[0]], numpy.expand_dims(predicted, 1), axis=1 + scores[positive], numpy.expand_dims(predicted, 1), axis=1 ) argmax_probs = argmax_probs.squeeze() sort_idx = (argmax_probs * -1).argsort()