Cache the label map

This commit is contained in:
Lj Miranda 2022-09-05 14:34:49 +08:00
parent 2bbab641e9
commit dbfb3a7739

View File

@ -215,17 +215,18 @@ class SpanCategorizerExclusive(TrainablePipe):
return list(self.labels) return list(self.labels)
@property @property
def _negative_label(self): def label_map(self) -> Dict[str, int]:
""" """RETURNS (Dict[str, int]): The label map."""
Index of the negative label. return {label: i for i, label in enumerate(self.labels)}
"""
@property
def _negative_label(self) -> int:
"""RETURNS (int): Index of the negative label."""
return len(self.label_data) return len(self.label_data)
@property @property
def _n_labels(self): def _n_labels(self) -> int:
""" """RETURNS (int): Number of labels including the negative label."""
Number of labels including the negative label.
"""
return len(self.label_data) + 1 return len(self.label_data) + 1
def predict(self, docs: Iterable[Doc]): def predict(self, docs: Iterable[Doc]):
@ -339,7 +340,6 @@ class SpanCategorizerExclusive(TrainablePipe):
spans = Ragged( spans = Ragged(
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths) self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
) )
label_map = {label: i for i, label in enumerate(self.labels)}
target = numpy.zeros(scores.shape, dtype=scores.dtype) target = numpy.zeros(scores.shape, dtype=scores.dtype)
# Set negative class as target initially for all samples. # Set negative class as target initially for all samples.
negative_spans = numpy.ones((scores.shape[0])) negative_spans = numpy.ones((scores.shape[0]))
@ -358,7 +358,7 @@ class SpanCategorizerExclusive(TrainablePipe):
key = (gold_span.start, gold_span.end) key = (gold_span.start, gold_span.end)
if key in spans_index: if key in spans_index:
row = spans_index[key] row = spans_index[key]
k = label_map[gold_span.label_] k = self.label_map[gold_span.label_]
target[row, k] = 1.0 target[row, k] = 1.0
# delete negative label target. # delete negative label target.
negative_spans[row] = 0.0 negative_spans[row] = 0.0
@ -438,14 +438,15 @@ class SpanCategorizerExclusive(TrainablePipe):
predicted = scores.argmax(axis=1) predicted = scores.argmax(axis=1)
# Remove samples where the negative label is the argmax # Remove samples where the negative label is the argmax
positive = numpy.where(predicted != self._negative_label) positive = numpy.where(predicted != self._negative_label)[0]
predicted = predicted[positive[0]] predicted = predicted[positive]
indices = indices[positive[0]] indices = indices[positive]
# Sort spans according to argmax probability # Sort spans according to argmax probability
if not allow_overlap: if not allow_overlap:
# Get the probabilities
argmax_probs = numpy.take_along_axis( argmax_probs = numpy.take_along_axis(
scores[positive[0]], numpy.expand_dims(predicted, 1), axis=1 scores[positive], numpy.expand_dims(predicted, 1), axis=1
) )
argmax_probs = argmax_probs.squeeze() argmax_probs = argmax_probs.squeeze()
sort_idx = (argmax_probs * -1).argsort() sort_idx = (argmax_probs * -1).argsort()