Refactor make_span_group: drop the `labels` argument and rename `_negative_label` to `_negative_label_i`

This commit is contained in:
kadarakos 2023-02-10 14:06:56 +00:00
parent a281a7c9a1
commit a07aafc28e

View File

@ -182,7 +182,7 @@ def make_spancat(
name=name, name=name,
spans_key=spans_key, spans_key=spans_key,
negative_weight=None, negative_weight=None,
allow_overlap=None, allow_overlap=True,
max_positive=max_positive, max_positive=max_positive,
threshold=threshold, threshold=threshold,
scorer=scorer, scorer=scorer,
@ -303,16 +303,16 @@ class SpanCategorizer(TrainablePipe):
vocab: Vocab, vocab: Vocab,
model: Model[Tuple[List[Doc], Ragged], Floats2d], model: Model[Tuple[List[Doc], Ragged], Floats2d],
suggester: Suggester, suggester: Suggester,
# XXX Not sure what's the best default name when it can both be spancat # TODO Not sure what's the best default name when it can both be spancat
# and spancat_singlelabel # and spancat_singlelabel
name: str = "spancat", name: str = "spancat",
*, *,
add_negative_label: bool = False, add_negative_label: bool = False,
spans_key: str = "spans", spans_key: str = "spans",
negative_weight: Optional[float] = None, negative_weight: Optional[float] = 1.0,
allow_overlap: Optional[bool] = None, allow_overlap: Optional[bool] = True,
max_positive: Optional[int] = None, max_positive: Optional[int] = None,
threshold: Optional[float] = None, threshold: Optional[float] = 0.5,
scorer: Optional[Callable] = spancat_score, scorer: Optional[Callable] = spancat_score,
) -> None: ) -> None:
"""Initialize the multilabel or multiclass span categorizer. """Initialize the multilabel or multiclass span categorizer.
@ -366,9 +366,7 @@ class SpanCategorizer(TrainablePipe):
self.name = name self.name = name
self.scorer = scorer self.scorer = scorer
self.add_negative_label = add_negative_label self.add_negative_label = add_negative_label
if allow_overlap is None: if not allow_overlap and max_positive is not None and max_positive > 1:
self.cfg["allow_overlap"] = True
elif not allow_overlap and max_positive is not None and max_positive > 1:
self.cfg["allow_overlap"] = True self.cfg["allow_overlap"] = True
msg.warn( msg.warn(
"'allow_overlap' can only be False when max_positive=1, " "'allow_overlap' can only be False when max_positive=1, "
@ -446,7 +444,7 @@ class SpanCategorizer(TrainablePipe):
return len(self.labels) return len(self.labels)
@property @property
def _negative_label(self) -> Union[int, None]: def _negative_label_i(self) -> Union[int, None]:
"""RETURNS (Union[int, None]): Index of the negative label.""" """RETURNS (Union[int, None]): Index of the negative label."""
if self.add_negative_label: if self.add_negative_label:
return len(self.label_data) return len(self.label_data)
@ -494,7 +492,6 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#set_annotations DOCS: https://spacy.io/api/spancategorizer#set_annotations
""" """
labels = self.labels
indices, scores = indices_scores indices, scores = indices_scores
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
@ -506,7 +503,6 @@ class SpanCategorizer(TrainablePipe):
doc, doc,
indices_i, indices_i,
scores[offset : offset + indices.lengths[i]], scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type]
allow_overlap, allow_overlap,
) )
else: else:
@ -514,7 +510,6 @@ class SpanCategorizer(TrainablePipe):
doc, doc,
indices_i, indices_i,
scores[offset : offset + indices.lengths[i]], scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type]
) )
offset += indices.lengths[i] offset += indices.lengths[i]
@ -604,7 +599,7 @@ class SpanCategorizer(TrainablePipe):
target = self.model.ops.asarray(target, dtype="f") # type: ignore target = self.model.ops.asarray(target, dtype="f") # type: ignore
if self.add_negative_label: if self.add_negative_label:
negative_samples = numpy.nonzero(negative_spans)[0] negative_samples = numpy.nonzero(negative_spans)[0]
target[negative_samples, self._negative_label] = 1.0 # type: ignore target[negative_samples, self._negative_label_i] = 1.0 # type: ignore
# The target will have the values 0 (for untrue predictions) or 1 # The target will have the values 0 (for untrue predictions) or 1
# (for true predictions). # (for true predictions).
# The scores should be in the range [0, 1]. # The scores should be in the range [0, 1].
@ -672,9 +667,8 @@ class SpanCategorizer(TrainablePipe):
doc: Doc, doc: Doc,
indices: Ints2d, indices: Ints2d,
scores: Floats2d, scores: Floats2d,
labels: List[str],
) -> SpanGroup: ) -> SpanGroup:
# Handle cases when there are zero suggestions """Find the top-k labels for each span (k=max_positive)."""
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
if scores.size == 0: if scores.size == 0:
return spans return spans
@ -684,9 +678,15 @@ class SpanCategorizer(TrainablePipe):
max_positive = self.cfg["max_positive"] max_positive = self.cfg["max_positive"]
keeps = scores >= threshold keeps = scores >= threshold
ranked = (scores * -1).argsort() # type: ignore
if max_positive is not None: if max_positive is not None:
assert isinstance(max_positive, int) assert isinstance(max_positive, int)
if self.add_negative_label:
negative_scores = numpy.copy(scores[:, self._negative_label_i])
scores[:, self._negative_label_i] = -numpy.inf
ranked = (scores * -1).argsort() # type: ignore
scores[:, self._negative_label_i] = negative_scores
else:
ranked = (scores * -1).argsort() # type: ignore
span_filter = ranked[:, max_positive:] span_filter = ranked[:, max_positive:]
for i, row in enumerate(span_filter): for i, row in enumerate(span_filter):
keeps[i, row] = False keeps[i, row] = False
@ -697,11 +697,8 @@ class SpanCategorizer(TrainablePipe):
end = indices[i, 1] end = indices[i, 1]
for j, keep in enumerate(keeps[i]): for j, keep in enumerate(keeps[i]):
if keep: if keep:
# If the predicted label is the negative label skip it. if j != self._negative_label_i:
if self.add_negative_label and labels[j] == self._negative_label: spans.append(Span(doc, start, end, label=self.labels[j]))
continue
else:
spans.append(Span(doc, start, end, label=labels[j]))
attrs_scores.append(scores[i, j]) attrs_scores.append(scores[i, j])
spans.attrs["scores"] = numpy.array(attrs_scores) spans.attrs["scores"] = numpy.array(attrs_scores)
return spans return spans
@ -711,9 +708,9 @@ class SpanCategorizer(TrainablePipe):
doc: Doc, doc: Doc,
indices: Ints2d, indices: Ints2d,
scores: Floats2d, scores: Floats2d,
labels: List[str],
allow_overlap: bool = True, allow_overlap: bool = True,
) -> SpanGroup: ) -> SpanGroup:
"""Find the argmax label for each span."""
# Handle cases when there are zero suggestions # Handle cases when there are zero suggestions
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
if scores.size == 0: if scores.size == 0:
@ -727,7 +724,7 @@ class SpanCategorizer(TrainablePipe):
keeps = numpy.ones(predicted.shape, dtype=bool) keeps = numpy.ones(predicted.shape, dtype=bool)
# Remove samples where the negative label is the argmax. # Remove samples where the negative label is the argmax.
if self.add_negative_label: if self.add_negative_label:
keeps = numpy.logical_and(keeps, predicted != self._negative_label) keeps = numpy.logical_and(keeps, predicted != self._negative_label_i)
# Filter samples according to threshold. # Filter samples according to threshold.
threshold = self.cfg["threshold"] threshold = self.cfg["threshold"]
if threshold is not None: if threshold is not None:
@ -743,11 +740,12 @@ class SpanCategorizer(TrainablePipe):
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
attrs_scores = [] attrs_scores = []
for i in range(indices.shape[0]): for i in range(indices.shape[0]):
if not keeps[i]:
continue
label = predicted[i] label = predicted[i]
start = indices[i, 0] start = indices[i, 0]
end = indices[i, 1] end = indices[i, 1]
if not keeps[i]:
continue
if not allow_overlap: if not allow_overlap:
if (start, end) in seen: if (start, end) in seen:
@ -755,6 +753,6 @@ class SpanCategorizer(TrainablePipe):
else: else:
seen.add(start, end) seen.add(start, end)
attrs_scores.append(argmax_scores[i]) attrs_scores.append(argmax_scores[i])
spans.append(Span(doc, start, end, label=labels[label])) spans.append(Span(doc, start, end, label=self.labels[label]))
return spans return spans