Fix mypy errors

However, I ignored line 370 because it opened up a bunch of type errors
that might be trickier to solve and might lead to a more complicated
codebase.
This commit is contained in:
Lj Miranda 2022-09-05 15:42:34 +08:00
parent dbfb3a7739
commit 2b7eb85e36

View File

@ -268,7 +268,7 @@ class SpanCategorizerExclusive(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizerexclusive#set_annotations
"""
allow_overlap = self.cfg["allow_overlap"]
allow_overlap = cast(bool, self.cfg["allow_overlap"])
labels = self.labels
indices, scores = indices_scores
offset = 0
@ -278,9 +278,9 @@ class SpanCategorizerExclusive(TrainablePipe):
doc,
indices_i,
scores[offset : offset + indices.lengths[i]],
labels,
labels, # type: ignore[arg-type]
allow_overlap,
) # type: ignore[arg-type]
)
offset += indices.lengths[i]
def update(
@ -367,9 +367,9 @@ class SpanCategorizerExclusive(TrainablePipe):
offset += spans.lengths[i]
target = self.model.ops.asarray(target, dtype="f") # type: ignore
negative_samples = numpy.nonzero(negative_spans)[0]
target[negative_samples, self._negative_label] = 1.0
target[negative_samples, self._negative_label] = 1.0 # type: ignore
d_scores = scores - target
neg_weight = self.cfg["negative_weight"]
neg_weight = cast(float, self.cfg["negative_weight"])
d_scores[negative_samples] *= neg_weight
loss = float((d_scores**2).sum())
return loss, d_scores