Cache the label map

This commit is contained in:
Lj Miranda 2022-09-05 14:34:49 +08:00
parent 2bbab641e9
commit dbfb3a7739

View File

@ -215,17 +215,18 @@ class SpanCategorizerExclusive(TrainablePipe):
return list(self.labels)
@property
def _negative_label(self):
"""
Index of the negative label.
"""
def label_map(self) -> Dict[str, int]:
    """Map every label to its position in `self.labels`.

    The dictionary is rebuilt from `self.labels` on each access, so callers
    that need it repeatedly (e.g. inside a loop) may want to bind the result
    to a local name first.

    RETURNS (Dict[str, int]): The label map.
    """
    mapping = {}
    for index, name in enumerate(self.labels):
        mapping[name] = index
    return mapping
@property
def _negative_label(self) -> int:
    """Index assigned to the negative label.

    It equals the number of entries in `self.label_data`, i.e. the first
    index past the entries of `self.label_data`.

    RETURNS (int): Index of the negative label.
    """
    label_count = len(self.label_data)
    return label_count
@property
def _n_labels(self):
"""
Number of labels including the negative label.
"""
def _n_labels(self) -> int:
    """Total label count, with one extra slot for the negative label.

    RETURNS (int): Number of labels including the negative label.
    """
    return 1 + len(self.label_data)
def predict(self, docs: Iterable[Doc]):
@ -339,7 +340,6 @@ class SpanCategorizerExclusive(TrainablePipe):
spans = Ragged(
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
)
label_map = {label: i for i, label in enumerate(self.labels)}
target = numpy.zeros(scores.shape, dtype=scores.dtype)
# Set negative class as target initially for all samples.
negative_spans = numpy.ones((scores.shape[0]))
@ -358,7 +358,7 @@ class SpanCategorizerExclusive(TrainablePipe):
key = (gold_span.start, gold_span.end)
if key in spans_index:
row = spans_index[key]
k = label_map[gold_span.label_]
k = self.label_map[gold_span.label_]
target[row, k] = 1.0
# delete negative label target.
negative_spans[row] = 0.0
@ -438,14 +438,15 @@ class SpanCategorizerExclusive(TrainablePipe):
predicted = scores.argmax(axis=1)
# Remove samples where the negative label is the argmax
positive = numpy.where(predicted != self._negative_label)
predicted = predicted[positive[0]]
indices = indices[positive[0]]
positive = numpy.where(predicted != self._negative_label)[0]
predicted = predicted[positive]
indices = indices[positive]
# Sort spans according to argmax probability
if not allow_overlap:
# Get the probabilities
argmax_probs = numpy.take_along_axis(
scores[positive[0]], numpy.expand_dims(predicted, 1), axis=1
scores[positive], numpy.expand_dims(predicted, 1), axis=1
)
argmax_probs = argmax_probs.squeeze()
sort_idx = (argmax_probs * -1).argsort()