diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index b5872daf3..9a4eca4db 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -731,25 +731,23 @@ class SpanCategorizer(TrainablePipe): # Filter samples according to threshold. threshold = self.cfg["threshold"] if threshold is not None: - print(argmax_scores >= threshold) - keeps = numpy.logical_and(keeps, argmax_scores >= threshold) + keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze()) # Sort spans according to argmax probability if not allow_overlap: # Get the probabilities - sort_idx = (argmax_scores * -1).argsort() + sort_idx = (argmax_scores.squeeze() * -1).argsort() predicted = predicted[sort_idx] indices = indices[sort_idx] - - # TODO assigns spans.attrs["scores"] + keeps = keeps[sort_idx] seen = Intervals() spans = SpanGroup(doc, name=self.key) attrs_scores = [] for i in range(indices.shape[0]): - if not keeps[i]: - continue label = predicted[i] start = indices[i, 0] end = indices[i, 1] + if not keeps[i]: + continue if not allow_overlap: if (start, end) in seen: