mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-05 21:53:05 +03:00
Fix mypy errors
However, I ignored line 370 because it opened up a bunch of type errors that might be trickier to solve and might lead to a more complicated codebase.
This commit is contained in:
parent
dbfb3a7739
commit
2b7eb85e36
|
@ -268,7 +268,7 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/spancategorizerexclusive#set_annotations
|
||||
"""
|
||||
allow_overlap = self.cfg["allow_overlap"]
|
||||
allow_overlap = cast(bool, self.cfg["allow_overlap"])
|
||||
labels = self.labels
|
||||
indices, scores = indices_scores
|
||||
offset = 0
|
||||
|
@ -278,9 +278,9 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
doc,
|
||||
indices_i,
|
||||
scores[offset : offset + indices.lengths[i]],
|
||||
labels,
|
||||
labels, # type: ignore[arg-type]
|
||||
allow_overlap,
|
||||
) # type: ignore[arg-type]
|
||||
)
|
||||
offset += indices.lengths[i]
|
||||
|
||||
def update(
|
||||
|
@ -367,9 +367,9 @@ class SpanCategorizerExclusive(TrainablePipe):
|
|||
offset += spans.lengths[i]
|
||||
target = self.model.ops.asarray(target, dtype="f") # type: ignore
|
||||
negative_samples = numpy.nonzero(negative_spans)[0]
|
||||
target[negative_samples, self._negative_label] = 1.0
|
||||
target[negative_samples, self._negative_label] = 1.0 # type: ignore
|
||||
d_scores = scores - target
|
||||
neg_weight = self.cfg["negative_weight"]
|
||||
neg_weight = cast(float, self.cfg["negative_weight"])
|
||||
d_scores[negative_samples] *= neg_weight
|
||||
loss = float((d_scores**2).sum())
|
||||
return loss, d_scores
|
||||
|
|
Loading…
Reference in New Issue
Block a user