Formatting

This commit is contained in:
Paul O'Leary McCann 2022-11-14 19:31:50 +09:00
parent f51a63863d
commit 970ca8a4f1

View File

@ -293,7 +293,7 @@ class TextCategorizer(TrainablePipe):
bp_scores(gradient) bp_scores(gradient)
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
losses[self.name] += (gradient**2).sum() losses[self.name] += (gradient ** 2).sum()
return losses return losses
def _examples_to_truth( def _examples_to_truth(
@ -327,7 +327,7 @@ class TextCategorizer(TrainablePipe):
not_missing = self.model.ops.asarray(not_missing) # type: ignore not_missing = self.model.ops.asarray(not_missing) # type: ignore
d_scores = scores - truths d_scores = scores - truths
d_scores *= not_missing d_scores *= not_missing
mean_square_error = (d_scores**2).mean() mean_square_error = (d_scores ** 2).mean()
return float(mean_square_error), d_scores return float(mean_square_error), d_scores
def add_label(self, label: str) -> int: def add_label(self, label: str) -> int: