mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-24 19:11:58 +03:00
Cache the label map
This commit is contained in:
parent
2bbab641e9
commit
dbfb3a7739
|
@ -215,17 +215,18 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
return list(self.labels)
|
||||
|
||||
@property
|
||||
def _negative_label(self):
|
||||
"""
|
||||
Index of the negative label.
|
||||
"""
|
||||
def label_map(self) -> Dict[str, int]:
|
||||
"""RETURNS (Dict[str, int]): The label map."""
|
||||
return {label: i for i, label in enumerate(self.labels)}
|
||||
|
||||
@property
|
||||
def _negative_label(self) -> int:
|
||||
"""RETURNS (int): Index of the negative label."""
|
||||
return len(self.label_data)
|
||||
|
||||
@property
|
||||
def _n_labels(self):
|
||||
"""
|
||||
Number of labels including the negative label.
|
||||
"""
|
||||
def _n_labels(self) -> int:
|
||||
"""RETURNS (int): Number of labels including the negative label."""
|
||||
return len(self.label_data) + 1
|
||||
|
||||
def predict(self, docs: Iterable[Doc]):
|
||||
|
@ -339,7 +340,6 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
spans = Ragged(
|
||||
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
|
||||
)
|
||||
label_map = {label: i for i, label in enumerate(self.labels)}
|
||||
target = numpy.zeros(scores.shape, dtype=scores.dtype)
|
||||
# Set negative class as target initially for all samples.
|
||||
negative_spans = numpy.ones((scores.shape[0]))
|
||||
|
@ -358,7 +358,7 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
key = (gold_span.start, gold_span.end)
|
||||
if key in spans_index:
|
||||
row = spans_index[key]
|
||||
k = label_map[gold_span.label_]
|
||||
k = self.label_map[gold_span.label_]
|
||||
target[row, k] = 1.0
|
||||
# delete negative label target.
|
||||
negative_spans[row] = 0.0
|
||||
|
@ -438,14 +438,15 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
predicted = scores.argmax(axis=1)
|
||||
|
||||
# Remove samples where the negative label is the argmax
|
||||
positive = numpy.where(predicted != self._negative_label)
|
||||
predicted = predicted[positive[0]]
|
||||
indices = indices[positive[0]]
|
||||
positive = numpy.where(predicted != self._negative_label)[0]
|
||||
predicted = predicted[positive]
|
||||
indices = indices[positive]
|
||||
|
||||
# Sort spans according to argmax probability
|
||||
if not allow_overlap:
|
||||
# Get the probabilities
|
||||
argmax_probs = numpy.take_along_axis(
|
||||
scores[positive[0]], numpy.expand_dims(predicted, 1), axis=1
|
||||
scores[positive], numpy.expand_dims(predicted, 1), axis=1
|
||||
)
|
||||
argmax_probs = argmax_probs.squeeze()
|
||||
sort_idx = (argmax_probs * -1).argsort()
|
||||
|
|
Loading…
Reference in New Issue
Block a user