replace single_label with add_negative_label and adjust inference

This commit is contained in:
kadarakos 2023-02-06 18:54:30 +00:00
parent c864f12e28
commit c24b3785a6

View File

@ -3,6 +3,7 @@ from dataclasses import dataclass
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d from thinc.types import Ragged, Ints2d, Floats2d
from wasabi import msg
import numpy import numpy
@ -179,13 +180,13 @@ def make_spancat(
model=model, model=model,
suggester=suggester, suggester=suggester,
name=name, name=name,
single_label=False,
spans_key=spans_key, spans_key=spans_key,
negative_weight=None, negative_weight=None,
allow_overlap=None, allow_overlap=None,
max_positive=max_positive, max_positive=max_positive,
threshold=threshold, threshold=threshold,
scorer=scorer, scorer=scorer,
add_negative_label=False,
) )
@ -242,11 +243,11 @@ def make_spancat_singlelabel(
model=model, model=model,
suggester=suggester, suggester=suggester,
name=name, name=name,
single_label=True,
spans_key=spans_key, spans_key=spans_key,
negative_weight=negative_weight, negative_weight=negative_weight,
allow_overlap=allow_overlap, allow_overlap=allow_overlap,
max_positive=None, max_positive=1,
add_negative_label=True,
threshold=None, threshold=None,
scorer=scorer, scorer=scorer,
) )
@ -271,7 +272,7 @@ def make_spancat_scorer():
@dataclass @dataclass
class Ranges: class Intervals:
""" """
Helper class to avoid storing overlapping spans. Helper class to avoid storing overlapping spans.
""" """
@ -306,7 +307,7 @@ class SpanCategorizer(TrainablePipe):
# and spancat_singlelabel # and spancat_singlelabel
name: str = "spancat", name: str = "spancat",
*, *,
single_label: bool = False, add_negative_label: bool = False,
spans_key: str = "spans", spans_key: str = "spans",
negative_weight: Optional[float] = None, negative_weight: Optional[float] = None,
allow_overlap: Optional[bool] = None, allow_overlap: Optional[bool] = None,
@ -331,14 +332,11 @@ class SpanCategorizer(TrainablePipe):
During initialization and training, the component will look for During initialization and training, the component will look for
spans on the reference document under the same key. Defaults to spans on the reference document under the same key. Defaults to
`"spans"`. `"spans"`.
single_label (bool): Whether to configure SpanCategorizer to produce add_negative_label (bool): Learn to predict a special 'negative_label'
a single label per span. In this case its expected that the scorer when a Span is not annotated.
layer is Softmax. Otherwise its expected to be Logistic. When single_label
is true the SpanCategorizer internally has a negative-label indicating
that a span should not receive any of the labels found in the corpus.
threshold (Optional[float]): Minimum probability to consider a prediction threshold (Optional[float]): Minimum probability to consider a prediction
positive in the multilabel usecase.Defaults to 0.5 when single_label is positive. Defaults to 0.5. Spans with a positive prediction will be saved
False otherwise its None. Spans with a positive prediction will be saved on the Doc. on the Doc.
max_positive (Optional[int]): Maximum number of labels to consider max_positive (Optional[int]): Maximum number of labels to consider
positive per span. Defaults to None, indicating no limit. This is positive per span. Defaults to None, indicating no limit. This is
unused when single_label is True. unused when single_label is True.
@ -367,7 +365,16 @@ class SpanCategorizer(TrainablePipe):
self.model = model self.model = model
self.name = name self.name = name
self.scorer = scorer self.scorer = scorer
self.single_label = single_label self.add_negative_label = add_negative_label
if allow_overlap is None:
self.cfg["allow_overlap"] = True
elif not allow_overlap and max_positive > 1:
self.cfg["allow_overlap"] = True
msg.warn(
"'allow_overlap' can only be False when max_positive=1, "
f"but found 'max_positive': {max_positive} "
"SpanCategorizer is automatically configured with allow_overlap=True."
)
@property @property
def key(self) -> str: def key(self) -> str:
@ -433,7 +440,7 @@ class SpanCategorizer(TrainablePipe):
@property @property
def _n_labels(self) -> int: def _n_labels(self) -> int:
"""RETURNS (int): Number of labels.""" """RETURNS (int): Number of labels."""
if self.single_label: if self.add_negative_label:
return len(self.labels) + 1 return len(self.labels) + 1
else: else:
return len(self.labels) return len(self.labels)
@ -441,7 +448,7 @@ class SpanCategorizer(TrainablePipe):
@property @property
def _negative_label(self) -> Union[int, None]: def _negative_label(self) -> Union[int, None]:
"""RETURNS (Union[int, None]): Index of the negative label.""" """RETURNS (Union[int, None]): Index of the negative label."""
if self.single_label: if self.add_negative_label:
return len(self.label_data) return len(self.label_data)
else: else:
return None return None
@ -492,13 +499,9 @@ class SpanCategorizer(TrainablePipe):
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
indices_i = indices[i].dataXd indices_i = indices[i].dataXd
if self.single_label: allow_overlap = self.cfg["allow_overlap"]
allow_overlap = self.cfg["allow_overlap"] allow_overlap = cast(bool, self.cfg["allow_overlap"])
# Interpret None as False if allow_overlap is not provided if self.cfg["max_positive"] == 1:
if allow_overlap is None:
allow_overlap = False
else:
allow_overlap = cast(bool, self.cfg["allow_overlap"])
doc.spans[self.key] = self._make_span_group_singlelabel( doc.spans[self.key] = self._make_span_group_singlelabel(
doc, doc,
indices_i, indices_i,
@ -513,7 +516,6 @@ class SpanCategorizer(TrainablePipe):
scores[offset : offset + indices.lengths[i]], scores[offset : offset + indices.lengths[i]],
labels, # type: ignore[arg-type] labels, # type: ignore[arg-type]
) )
offset += indices.lengths[i] offset += indices.lengths[i]
def update( def update(
@ -574,7 +576,7 @@ class SpanCategorizer(TrainablePipe):
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths) self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
) )
target = numpy.zeros(scores.shape, dtype=scores.dtype) target = numpy.zeros(scores.shape, dtype=scores.dtype)
if self.single_label: if self.add_negative_label:
negative_spans = numpy.ones((scores.shape[0])) negative_spans = numpy.ones((scores.shape[0]))
offset = 0 offset = 0
for i, eg in enumerate(examples): for i, eg in enumerate(examples):
@ -593,14 +595,14 @@ class SpanCategorizer(TrainablePipe):
row = spans_index[key] row = spans_index[key]
k = self.label_map[gold_span.label_] k = self.label_map[gold_span.label_]
target[row, k] = 1.0 target[row, k] = 1.0
if self.single_label: if self.add_negative_label:
# delete negative label target. # delete negative label target.
negative_spans[row] = 0.0 negative_spans[row] = 0.0
# The target is a flat array for all docs. Track the position # The target is a flat array for all docs. Track the position
# we're at within the flat array. # we're at within the flat array.
offset += spans.lengths[i] offset += spans.lengths[i]
target = self.model.ops.asarray(target, dtype="f") # type: ignore target = self.model.ops.asarray(target, dtype="f") # type: ignore
if self.single_label: if self.add_negative_label:
negative_samples = numpy.nonzero(negative_spans)[0] negative_samples = numpy.nonzero(negative_spans)[0]
target[negative_samples, self._negative_label] = 1.0 # type: ignore target[negative_samples, self._negative_label] = 1.0 # type: ignore
# The target will have the values 0 (for untrue predictions) or 1 # The target will have the values 0 (for untrue predictions) or 1
@ -611,7 +613,7 @@ class SpanCategorizer(TrainablePipe):
# If the prediction is 0.9 and it's false, the gradient will be # If the prediction is 0.9 and it's false, the gradient will be
# 0.9 (0.9 - 0.0) # 0.9 (0.9 - 0.0)
d_scores = scores - target d_scores = scores - target
if self.single_label: if self.add_negative_label:
neg_weight = cast(float, self.cfg["negative_weight"]) neg_weight = cast(float, self.cfg["negative_weight"])
if neg_weight != 1.0: if neg_weight != 1.0:
d_scores[negative_samples] *= neg_weight d_scores[negative_samples] *= neg_weight
@ -671,12 +673,15 @@ class SpanCategorizer(TrainablePipe):
indices: Ints2d, indices: Ints2d,
scores: Floats2d, scores: Floats2d,
labels: List[str], labels: List[str],
# XXX Unused, does it make sense?
allow_overlap: bool = True,
) -> SpanGroup: ) -> SpanGroup:
# Handle cases when there are zero suggestions
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
max_positive = self.cfg["max_positive"] if scores.size == 0:
return spans
scores = self.model.ops.to_numpy(scores)
indices = self.model.ops.to_numpy(indices)
threshold = self.cfg["threshold"] threshold = self.cfg["threshold"]
max_positive = self.cfg["max_positive"]
keeps = scores >= threshold keeps = scores >= threshold
ranked = (scores * -1).argsort() # type: ignore ranked = (scores * -1).argsort() # type: ignore
@ -685,19 +690,20 @@ class SpanCategorizer(TrainablePipe):
span_filter = ranked[:, max_positive:] span_filter = ranked[:, max_positive:]
for i, row in enumerate(span_filter): for i, row in enumerate(span_filter):
keeps[i, row] = False keeps[i, row] = False
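# TODO: verify — scores[keeps] is stored before negative-label entries are
# skipped in the loop below, so attrs["scores"] may not align with the spans.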
spans.attrs["scores"] = scores[keeps].flatten() spans.attrs["scores"] = scores[keeps].flatten()
indices = self.model.ops.to_numpy(indices)
keeps = self.model.ops.to_numpy(keeps)
for i in range(indices.shape[0]): for i in range(indices.shape[0]):
start = indices[i, 0] start = indices[i, 0]
end = indices[i, 1] end = indices[i, 1]
for j, keep in enumerate(keeps[i]): for j, keep in enumerate(keeps[i]):
if keep: if keep:
spans.append(Span(doc, start, end, label=labels[j])) # If the predicted label is the negative label skip it.
if self.add_negative_label and labels[j] == self._negative_label:
continue
else:
spans.append(Span(doc, start, end, label=labels[j]))
return spans return spans
@ -709,30 +715,36 @@ class SpanCategorizer(TrainablePipe):
labels: List[str], labels: List[str],
allow_overlap: bool = True, allow_overlap: bool = True,
) -> SpanGroup: ) -> SpanGroup:
# Handle cases when there are zero suggestions
spans = SpanGroup(doc, name=self.key)
if scores.size == 0:
return spans
scores = self.model.ops.to_numpy(scores) scores = self.model.ops.to_numpy(scores)
indices = self.model.ops.to_numpy(indices) indices = self.model.ops.to_numpy(indices)
# Handle cases when there are zero suggestions threshold = self.cfg["threshold"]
if scores.size == 0:
return SpanGroup(doc, name=self.key)
predicted = scores.argmax(axis=1) predicted = scores.argmax(axis=1)
# Remove samples where the negative label is the argmax argmax_scores = numpy.take_along_axis(
positive = numpy.where(predicted != self._negative_label)[0] scores, numpy.expand_dims(predicted, 1), axis=1
predicted = predicted[positive] ).squeeze()
indices = indices[positive] # Remove samples where the negative label is the argmax.
if self.add_negative_label:
positive = numpy.where(predicted != self._negative_label)[0]
predicted = predicted[positive]
indices = indices[positive]
# Filter samples according to threshold.
if threshold is not None:
keeps = numpy.where(argmax_scores >= threshold)
predicted = predicted[keeps]
indices = indices[keeps]
# Sort spans according to argmax probability # Sort spans according to argmax probability
if not allow_overlap and predicted.size != 0: if not allow_overlap:
# Get the probabilities # Get the probabilities
argmax_probs = numpy.take_along_axis( sort_idx = (argmax_scores * -1).argsort()
scores[positive], numpy.expand_dims(predicted, 1), axis=1
)
argmax_probs = argmax_probs.squeeze()
sort_idx = (argmax_probs * -1).argsort()
predicted = predicted[sort_idx] predicted = predicted[sort_idx]
indices = indices[sort_idx] indices = indices[sort_idx]
seen = Ranges() # TODO assigns spans.attrs["scores"]
seen = Intervals()
spans = SpanGroup(doc, name=self.key) spans = SpanGroup(doc, name=self.key)
for i in range(len(predicted)): for i in range(len(predicted)):
label = predicted[i] label = predicted[i]