From a07aafc28e12e92e3bd6577cc3a9a945ccc203e3 Mon Sep 17 00:00:00 2001 From: kadarakos Date: Fri, 10 Feb 2023 14:06:56 +0000 Subject: [PATCH] refactor make_span_group --- spacy/pipeline/spancat.py | 50 +++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 9a4eca4db..1de51bb09 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -182,7 +182,7 @@ def make_spancat( name=name, spans_key=spans_key, negative_weight=None, - allow_overlap=None, + allow_overlap=True, max_positive=max_positive, threshold=threshold, scorer=scorer, @@ -303,16 +303,16 @@ class SpanCategorizer(TrainablePipe): vocab: Vocab, model: Model[Tuple[List[Doc], Ragged], Floats2d], suggester: Suggester, - # XXX Not sure what's the best default name when it can both be spancat + # TODO Not sure what's the best default name when it can both be spancat # and spancat_singlelabel name: str = "spancat", *, add_negative_label: bool = False, spans_key: str = "spans", - negative_weight: Optional[float] = None, - allow_overlap: Optional[bool] = None, + negative_weight: Optional[float] = 1.0, + allow_overlap: Optional[bool] = True, max_positive: Optional[int] = None, - threshold: Optional[float] = None, + threshold: Optional[float] = 0.5, scorer: Optional[Callable] = spancat_score, ) -> None: """Initialize the multilabel or multiclass span categorizer. @@ -366,9 +366,7 @@ class SpanCategorizer(TrainablePipe): self.name = name self.scorer = scorer self.add_negative_label = add_negative_label - if allow_overlap is None: - self.cfg["allow_overlap"] = True - elif not allow_overlap and max_positive is not None and max_positive > 1: + if not allow_overlap and max_positive is not None and max_positive > 1: self.cfg["allow_overlap"] = True msg.warn( "'allow_overlap' can only be False when max_positive=1, " @@ -446,7 +444,7 @@ class SpanCategorizer(TrainablePipe): return len(self.labels) @property - def _negative_label(self) -> Union[int, None]: + def _negative_label_i(self) -> Union[int, None]: """RETURNS (Union[int, None]): Index of the negative label.""" if self.add_negative_label: return len(self.label_data) @@ -494,7 +492,6 @@ class SpanCategorizer(TrainablePipe): DOCS: https://spacy.io/api/spancategorizer#set_annotations """ - labels = self.labels indices, scores = indices_scores offset = 0 for i, doc in enumerate(docs): @@ -506,7 +503,6 @@ class SpanCategorizer(TrainablePipe): doc, indices_i, scores[offset : offset + indices.lengths[i]], - labels, # type: ignore[arg-type] allow_overlap, ) else: @@ -514,7 +510,6 @@ class SpanCategorizer(TrainablePipe): doc, indices_i, scores[offset : offset + indices.lengths[i]], - labels, # type: ignore[arg-type] ) offset += indices.lengths[i] @@ -604,7 +599,7 @@ class SpanCategorizer(TrainablePipe): target = self.model.ops.asarray(target, dtype="f") # type: ignore if self.add_negative_label: negative_samples = numpy.nonzero(negative_spans)[0] - target[negative_samples, self._negative_label] = 1.0 # type: ignore + target[negative_samples, self._negative_label_i] = 1.0 # type: ignore # The target will have the values 0 (for untrue predictions) or 1 # (for true predictions). # The scores should be in the range [0, 1]. @@ -672,9 +667,8 @@ class SpanCategorizer(TrainablePipe): doc: Doc, indices: Ints2d, scores: Floats2d, - labels: List[str], ) -> SpanGroup: - # Handle cases when there are zero suggestions + """Find the top-k labels for each span (k=max_positive).""" spans = SpanGroup(doc, name=self.key) if scores.size == 0: return spans @@ -684,9 +678,15 @@ class SpanCategorizer(TrainablePipe): max_positive = self.cfg["max_positive"] keeps = scores >= threshold - ranked = (scores * -1).argsort() # type: ignore if max_positive is not None: assert isinstance(max_positive, int) + if self.add_negative_label: + negative_scores = numpy.copy(scores[:, self._negative_label_i]) + scores[:, self._negative_label_i] = -numpy.inf + ranked = (scores * -1).argsort() # type: ignore + scores[:, self._negative_label_i] = negative_scores + else: + ranked = (scores * -1).argsort() # type: ignore span_filter = ranked[:, max_positive:] for i, row in enumerate(span_filter): keeps[i, row] = False @@ -697,11 +697,8 @@ class SpanCategorizer(TrainablePipe): end = indices[i, 1] for j, keep in enumerate(keeps[i]): if keep: - # If the predicted label is the negative label skip it. - if self.add_negative_label and labels[j] == self._negative_label: - continue - else: - spans.append(Span(doc, start, end, label=labels[j])) + if j != self._negative_label_i: + spans.append(Span(doc, start, end, label=self.labels[j])) attrs_scores.append(scores[i, j]) spans.attrs["scores"] = numpy.array(attrs_scores) return spans @@ -711,9 +708,9 @@ class SpanCategorizer(TrainablePipe): doc: Doc, indices: Ints2d, scores: Floats2d, - labels: List[str], allow_overlap: bool = True, ) -> SpanGroup: + """Find the argmax label for each span.""" # Handle cases when there are zero suggestions spans = SpanGroup(doc, name=self.key) if scores.size == 0: @@ -727,7 +724,7 @@ class SpanCategorizer(TrainablePipe): keeps = numpy.ones(predicted.shape, dtype=bool) # Remove samples where the negative label is the argmax. if self.add_negative_label: - keeps = numpy.logical_and(keeps, predicted != self._negative_label) + keeps = numpy.logical_and(keeps, predicted != self._negative_label_i) # Filter samples according to threshold. threshold = self.cfg["threshold"] if threshold is not None: @@ -743,11 +740,12 @@ class SpanCategorizer(TrainablePipe): spans = SpanGroup(doc, name=self.key) attrs_scores = [] for i in range(indices.shape[0]): + if not keeps[i]: + continue + label = predicted[i] start = indices[i, 0] end = indices[i, 1] - if not keeps[i]: - continue if not allow_overlap: if (start, end) in seen: @@ -755,6 +753,6 @@ class SpanCategorizer(TrainablePipe): else: seen.add(start, end) attrs_scores.append(argmax_scores[i]) - spans.append(Span(doc, start, end, label=labels[label])) + spans.append(Span(doc, start, end, label=self.labels[label])) return spans