avoid multiplication with 1.0

Co-authored-by: kadarakos <kadar.akos@gmail.com>
This commit is contained in:
Sofie Van Landeghem 2022-10-03 17:05:55 +02:00 committed by GitHub
parent 2b7eb85e36
commit 95c5bfcc78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -370,7 +370,8 @@ class SpanCategorizerExclusive(TrainablePipe):
target[negative_samples, self._negative_label] = 1.0 # type: ignore
d_scores = scores - target
neg_weight = cast(float, self.cfg["negative_weight"])
d_scores[negative_samples] *= neg_weight
if neg_weight != 1.0:
d_scores[negative_samples] *= neg_weight
loss = float((d_scores**2).sum())
return loss, d_scores