From c24b3785a61475e53c3e1ee88171c635d8eee80a Mon Sep 17 00:00:00 2001 From: kadarakos Date: Mon, 6 Feb 2023 18:54:30 +0000 Subject: [PATCH] replace single_label with add_negative_label and adjust inference --- spacy/pipeline/spancat.py | 114 +++++++++++++++++++++----------------- 1 file changed, 63 insertions(+), 51 deletions(-) diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index f2d01e245..3bbb186d4 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Optimizer from thinc.types import Ragged, Ints2d, Floats2d +from wasabi import msg import numpy @@ -179,13 +180,13 @@ def make_spancat( model=model, suggester=suggester, name=name, - single_label=False, spans_key=spans_key, negative_weight=None, allow_overlap=None, max_positive=max_positive, threshold=threshold, scorer=scorer, + add_negative_label=False, ) @@ -242,11 +243,11 @@ def make_spancat_singlelabel( model=model, suggester=suggester, name=name, - single_label=True, spans_key=spans_key, negative_weight=negative_weight, allow_overlap=allow_overlap, - max_positive=None, + max_positive=1, + add_negative_label=True, threshold=None, scorer=scorer, ) @@ -271,7 +272,7 @@ def make_spancat_scorer(): @dataclass -class Ranges: +class Intervals: """ Helper class to avoid storing overlapping spans. """ @@ -306,7 +307,7 @@ class SpanCategorizer(TrainablePipe): # and spancat_singlelabel name: str = "spancat", *, - single_label: bool = False, + add_negative_label: bool = False, spans_key: str = "spans", negative_weight: Optional[float] = None, allow_overlap: Optional[bool] = None, @@ -331,14 +332,11 @@ class SpanCategorizer(TrainablePipe): During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. - single_label (bool): Whether to configure SpanCategorizer to produce - a single label per span. In this case its expected that the scorer - layer is Softmax. Otherwise its expected to be Logistic. When single_label - is true the SpanCategorizer internally has a negative-label indicating - that a span should not receive any of the labels found in the corpus. + add_negative_label (bool): Learn to predict a special 'negative_label' + when a Span is not annotated. threshold (Optional[float]): Minimum probability to consider a prediction - positive in the multilabel usecase.Defaults to 0.5 when single_label is - False otherwise its None. Spans with a positive prediction will be saved on the Doc. + positive. Defaults to 0.5. Spans with a positive prediction will be saved + on the Doc. max_positive (Optional[int]): Maximum number of labels to consider positive per span. Defaults to None, indicating no limit. This is unused when single_label is True. @@ -367,7 +365,16 @@ class SpanCategorizer(TrainablePipe): self.model = model self.name = name self.scorer = scorer - self.single_label = single_label + self.add_negative_label = add_negative_label + if allow_overlap is None: + self.cfg["allow_overlap"] = True + elif not allow_overlap and max_positive > 1: + self.cfg["allow_overlap"] = True + msg.warn( + "'allow_overlap' can only be False when max_positive=1, " + f"but found 'max_positive': {max_positive} " + "SpanCategorizer is automatically configured with allow_overlap=True." + ) @property def key(self) -> str: @@ -433,7 +440,7 @@ class SpanCategorizer(TrainablePipe): @property def _n_labels(self) -> int: """RETURNS (int): Number of labels.""" - if self.single_label: + if self.add_negative_label: return len(self.labels) + 1 else: return len(self.labels) @@ -441,7 +448,7 @@ class SpanCategorizer(TrainablePipe): @property def _negative_label(self) -> Union[int, None]: """RETURNS (Union[int, None]): Index of the negative label.""" - if self.single_label: + if self.add_negative_label: return len(self.label_data) else: return None @@ -492,13 +499,9 @@ class SpanCategorizer(TrainablePipe): offset = 0 for i, doc in enumerate(docs): indices_i = indices[i].dataXd - if self.single_label: - allow_overlap = self.cfg["allow_overlap"] - # Interpret None as False if allow_overlap is not provided - if allow_overlap is None: - allow_overlap = False - else: - allow_overlap = cast(bool, self.cfg["allow_overlap"]) + allow_overlap = self.cfg["allow_overlap"] + allow_overlap = cast(bool, self.cfg["allow_overlap"]) + if self.cfg["max_positive"] == 1: doc.spans[self.key] = self._make_span_group_singlelabel( doc, indices_i, @@ -513,7 +516,6 @@ class SpanCategorizer(TrainablePipe): scores[offset : offset + indices.lengths[i]], labels, # type: ignore[arg-type] ) - offset += indices.lengths[i] def update( @@ -574,7 +576,7 @@ class SpanCategorizer(TrainablePipe): self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths) ) target = numpy.zeros(scores.shape, dtype=scores.dtype) - if self.single_label: + if self.add_negative_label: negative_spans = numpy.ones((scores.shape[0])) offset = 0 for i, eg in enumerate(examples): @@ -593,14 +595,14 @@ class SpanCategorizer(TrainablePipe): row = spans_index[key] k = self.label_map[gold_span.label_] target[row, k] = 1.0 - if self.single_label: + if self.add_negative_label: # delete negative label target. negative_spans[row] = 0.0 # The target is a flat array for all docs. Track the position # we're at within the flat array. offset += spans.lengths[i] target = self.model.ops.asarray(target, dtype="f") # type: ignore - if self.single_label: + if self.add_negative_label: negative_samples = numpy.nonzero(negative_spans)[0] target[negative_samples, self._negative_label] = 1.0 # type: ignore # The target will have the values 0 (for untrue predictions) or 1 @@ -611,7 +613,7 @@ class SpanCategorizer(TrainablePipe): # If the prediction is 0.9 and it's false, the gradient will be # 0.9 (0.9 - 0.0) d_scores = scores - target - if self.single_label: + if self.add_negative_label: neg_weight = cast(float, self.cfg["negative_weight"]) if neg_weight != 1.0: d_scores[negative_samples] *= neg_weight @@ -671,12 +673,15 @@ class SpanCategorizer(TrainablePipe): indices: Ints2d, scores: Floats2d, labels: List[str], - # XXX Unused, does it make sense? - allow_overlap: bool = True, ) -> SpanGroup: + # Handle cases when there are zero suggestions spans = SpanGroup(doc, name=self.key) - max_positive = self.cfg["max_positive"] + if scores.size == 0: + return spans + scores = self.model.ops.to_numpy(scores) + indices = self.model.ops.to_numpy(indices) threshold = self.cfg["threshold"] + max_positive = self.cfg["max_positive"] keeps = scores >= threshold ranked = (scores * -1).argsort() # type: ignore @@ -685,19 +690,20 @@ class SpanCategorizer(TrainablePipe): span_filter = ranked[:, max_positive:] for i, row in enumerate(span_filter): keeps[i, row] = False - + # TODO I think this is now incorrect spans.attrs["scores"] = scores[keeps].flatten() - indices = self.model.ops.to_numpy(indices) - keeps = self.model.ops.to_numpy(keeps) - for i in range(indices.shape[0]): start = indices[i, 0] end = indices[i, 1] for j, keep in enumerate(keeps[i]): if keep: - spans.append(Span(doc, start, end, label=labels[j])) + # If the predicted label is the negative label skip it. + if self.add_negative_label and labels[j] == self._negative_label: + continue + else: + spans.append(Span(doc, start, end, label=labels[j])) return spans @@ -709,30 +715,36 @@ class SpanCategorizer(TrainablePipe): labels: List[str], allow_overlap: bool = True, ) -> SpanGroup: + # Handle cases when there are zero suggestions + spans = SpanGroup(doc, name=self.key) + if scores.size == 0: + return spans scores = self.model.ops.to_numpy(scores) indices = self.model.ops.to_numpy(indices) - # Handle cases when there are zero suggestions - if scores.size == 0: - return SpanGroup(doc, name=self.key) - + threshold = self.cfg["threshold"] predicted = scores.argmax(axis=1) - # Remove samples where the negative label is the argmax - positive = numpy.where(predicted != self._negative_label)[0] - predicted = predicted[positive] - indices = indices[positive] - + argmax_scores = numpy.take_along_axis( + scores, numpy.expand_dims(predicted, 1), axis=1 + ).squeeze() + # Remove samples where the negative label is the argmax. + if self.add_negative_label: + positive = numpy.where(predicted != self._negative_label)[0] + predicted = predicted[positive] + indices = indices[positive] + # Filter samples according to threshold. + if threshold is not None: + keeps = numpy.where(argmax_scores >= threshold) + predicted = predicted[keeps] + indices = indices[keeps] # Sort spans according to argmax probability - if not allow_overlap and predicted.size != 0: + if not allow_overlap: # Get the probabilities - argmax_probs = numpy.take_along_axis( - scores[positive], numpy.expand_dims(predicted, 1), axis=1 - ) - argmax_probs = argmax_probs.squeeze() - sort_idx = (argmax_probs * -1).argsort() + sort_idx = (argmax_scores * -1).argsort() predicted = predicted[sort_idx] indices = indices[sort_idx] - seen = Ranges() + # TODO assigns spans.attrs["scores"] + seen = Intervals() spans = SpanGroup(doc, name=self.key) for i in range(len(predicted)): label = predicted[i]