replace single_label with add_negative_label and adjust inference

This commit is contained in:
kadarakos 2023-02-06 18:54:30 +00:00
parent c864f12e28
commit c24b3785a6

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d
from wasabi import msg
import numpy
@ -179,13 +180,13 @@ def make_spancat(
model=model,
suggester=suggester,
name=name,
single_label=False,
spans_key=spans_key,
negative_weight=None,
allow_overlap=None,
max_positive=max_positive,
threshold=threshold,
scorer=scorer,
add_negative_label=False,
)
@ -242,11 +243,11 @@ def make_spancat_singlelabel(
model=model,
suggester=suggester,
name=name,
single_label=True,
spans_key=spans_key,
negative_weight=negative_weight,
allow_overlap=allow_overlap,
max_positive=None,
max_positive=1,
add_negative_label=True,
threshold=None,
scorer=scorer,
)
@ -271,7 +272,7 @@ def make_spancat_scorer():
@dataclass
class Ranges:
class Intervals:
"""
Helper class to avoid storing overlapping spans.
"""
@ -306,7 +307,7 @@ class SpanCategorizer(TrainablePipe):
# and spancat_singlelabel
name: str = "spancat",
*,
single_label: bool = False,
add_negative_label: bool = False,
spans_key: str = "spans",
negative_weight: Optional[float] = None,
allow_overlap: Optional[bool] = None,
@ -331,14 +332,11 @@ class SpanCategorizer(TrainablePipe):
During initialization and training, the component will look for
spans on the reference document under the same key. Defaults to
`"spans"`.
single_label (bool): Whether to configure SpanCategorizer to produce
a single label per span. In this case its expected that the scorer
layer is Softmax. Otherwise its expected to be Logistic. When single_label
is true the SpanCategorizer internally has a negative-label indicating
that a span should not receive any of the labels found in the corpus.
add_negative_label (bool): Learn to predict a special 'negative_label'
when a Span is not annotated.
threshold (Optional[float]): Minimum probability to consider a prediction
positive in the multilabel usecase.Defaults to 0.5 when single_label is
False otherwise its None. Spans with a positive prediction will be saved on the Doc.
positive. Defaults to 0.5. Spans with a positive prediction will be saved
on the Doc.
max_positive (Optional[int]): Maximum number of labels to consider
positive per span. Defaults to None, indicating no limit. This is
unused when single_label is True.
@ -367,7 +365,16 @@ class SpanCategorizer(TrainablePipe):
self.model = model
self.name = name
self.scorer = scorer
self.single_label = single_label
self.add_negative_label = add_negative_label
if allow_overlap is None:
self.cfg["allow_overlap"] = True
elif not allow_overlap and max_positive > 1:
self.cfg["allow_overlap"] = True
msg.warn(
"'allow_overlap' can only be False when max_positive=1, "
f"but found 'max_positive': {max_positive} "
"SpanCategorizer is automatically configured with allow_overlap=True."
)
@property
def key(self) -> str:
@ -433,7 +440,7 @@ class SpanCategorizer(TrainablePipe):
@property
def _n_labels(self) -> int:
"""RETURNS (int): Number of labels."""
if self.single_label:
if self.add_negative_label:
return len(self.labels) + 1
else:
return len(self.labels)
@ -441,7 +448,7 @@ class SpanCategorizer(TrainablePipe):
@property
def _negative_label(self) -> Union[int, None]:
"""RETURNS (Union[int, None]): Index of the negative label."""
if self.single_label:
if self.add_negative_label:
return len(self.label_data)
else:
return None
@ -492,13 +499,9 @@ class SpanCategorizer(TrainablePipe):
offset = 0
for i, doc in enumerate(docs):
indices_i = indices[i].dataXd
if self.single_label:
allow_overlap = self.cfg["allow_overlap"]
# Interpret None as False if allow_overlap is not provided
if allow_overlap is None:
allow_overlap = False
else:
allow_overlap = cast(bool, self.cfg["allow_overlap"])
if self.cfg["max_positive"] == 1:
doc.spans[self.key] = self._make_span_group_singlelabel(
doc,
indices_i,
@ -513,7 +516,6 @@ class SpanCategorizer(TrainablePipe):
scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type]
)
offset += indices.lengths[i]
def update(
@ -574,7 +576,7 @@ class SpanCategorizer(TrainablePipe):
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
)
target = numpy.zeros(scores.shape, dtype=scores.dtype)
if self.single_label:
if self.add_negative_label:
negative_spans = numpy.ones((scores.shape[0]))
offset = 0
for i, eg in enumerate(examples):
@ -593,14 +595,14 @@ class SpanCategorizer(TrainablePipe):
row = spans_index[key]
k = self.label_map[gold_span.label_]
target[row, k] = 1.0
if self.single_label:
if self.add_negative_label:
# delete negative label target.
negative_spans[row] = 0.0
# The target is a flat array for all docs. Track the position
# we're at within the flat array.
offset += spans.lengths[i]
target = self.model.ops.asarray(target, dtype="f") # type: ignore
if self.single_label:
if self.add_negative_label:
negative_samples = numpy.nonzero(negative_spans)[0]
target[negative_samples, self._negative_label] = 1.0 # type: ignore
# The target will have the values 0 (for untrue predictions) or 1
@ -611,7 +613,7 @@ class SpanCategorizer(TrainablePipe):
# If the prediction is 0.9 and it's false, the gradient will be
# 0.9 (0.9 - 0.0)
d_scores = scores - target
if self.single_label:
if self.add_negative_label:
neg_weight = cast(float, self.cfg["negative_weight"])
if neg_weight != 1.0:
d_scores[negative_samples] *= neg_weight
@ -671,12 +673,15 @@ class SpanCategorizer(TrainablePipe):
indices: Ints2d,
scores: Floats2d,
labels: List[str],
# XXX Unused, does it make sense?
allow_overlap: bool = True,
) -> SpanGroup:
# Handle cases when there are zero suggestions
spans = SpanGroup(doc, name=self.key)
max_positive = self.cfg["max_positive"]
if scores.size == 0:
return spans
scores = self.model.ops.to_numpy(scores)
indices = self.model.ops.to_numpy(indices)
threshold = self.cfg["threshold"]
max_positive = self.cfg["max_positive"]
keeps = scores >= threshold
ranked = (scores * -1).argsort() # type: ignore
@ -685,18 +690,19 @@ class SpanCategorizer(TrainablePipe):
span_filter = ranked[:, max_positive:]
for i, row in enumerate(span_filter):
keeps[i, row] = False
# TODO I think this is now incorrect
spans.attrs["scores"] = scores[keeps].flatten()
indices = self.model.ops.to_numpy(indices)
keeps = self.model.ops.to_numpy(keeps)
for i in range(indices.shape[0]):
start = indices[i, 0]
end = indices[i, 1]
for j, keep in enumerate(keeps[i]):
if keep:
# If the predicted label is the negative label skip it.
if self.add_negative_label and labels[j] == self._negative_label:
continue
else:
spans.append(Span(doc, start, end, label=labels[j]))
return spans
@ -709,30 +715,36 @@ class SpanCategorizer(TrainablePipe):
labels: List[str],
allow_overlap: bool = True,
) -> SpanGroup:
# Handle cases when there are zero suggestions
spans = SpanGroup(doc, name=self.key)
if scores.size == 0:
return spans
scores = self.model.ops.to_numpy(scores)
indices = self.model.ops.to_numpy(indices)
# Handle cases when there are zero suggestions
if scores.size == 0:
return SpanGroup(doc, name=self.key)
threshold = self.cfg["threshold"]
predicted = scores.argmax(axis=1)
# Remove samples where the negative label is the argmax
argmax_scores = numpy.take_along_axis(
scores, numpy.expand_dims(predicted, 1), axis=1
).squeeze()
# Remove samples where the negative label is the argmax.
if self.add_negative_label:
positive = numpy.where(predicted != self._negative_label)[0]
predicted = predicted[positive]
indices = indices[positive]
# Filter samples according to threshold.
if threshold is not None:
keeps = numpy.where(argmax_scores >= threshold)
predicted = predicted[keeps]
indices = indices[keeps]
# Sort spans according to argmax probability
if not allow_overlap and predicted.size != 0:
if not allow_overlap:
# Get the probabilities
argmax_probs = numpy.take_along_axis(
scores[positive], numpy.expand_dims(predicted, 1), axis=1
)
argmax_probs = argmax_probs.squeeze()
sort_idx = (argmax_probs * -1).argsort()
sort_idx = (argmax_scores * -1).argsort()
predicted = predicted[sort_idx]
indices = indices[sort_idx]
seen = Ranges()
# TODO assigns spans.attrs["scores"]
seen = Intervals()
spans = SpanGroup(doc, name=self.key)
for i in range(len(predicted)):
label = predicted[i]