Fix black and mypy issues

This commit is contained in:
thomashacker 2023-03-31 12:00:14 +02:00
parent b6ad7e6d9b
commit 34d3720fc0
3 changed files with 5 additions and 5 deletions

View File

@ -295,7 +295,7 @@ class TextCategorizer(TrainablePipe):
""" """
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name+"_rehearse", 0.0) losses.setdefault(self.name + "_rehearse", 0.0)
if self._rehearsal_model is None: if self._rehearsal_model is None:
return losses return losses
validate_examples(examples, "TextCategorizer.rehearse") validate_examples(examples, "TextCategorizer.rehearse")
@ -311,7 +311,7 @@ class TextCategorizer(TrainablePipe):
bp_scores(gradient) bp_scores(gradient)
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
losses[self.name+"_rehearse"] += (gradient**2).sum() losses[self.name + "_rehearse"] += (gradient**2).sum()
return losses return losses
def _examples_to_truth( def _examples_to_truth(

View File

@ -2,7 +2,7 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
from typing import Optional, Any from typing import Optional, Any
from functools import partial from functools import partial
import itertools import itertools
from thinc.schedules import Schedule from thinc.schedules import Schedule #type:ignore[attr-defined]
from ..util import registry, minibatch from ..util import registry, minibatch
@ -221,7 +221,7 @@ def _batch_by_length(
if not batch: if not batch:
batch.append(i) batch.append(i)
elif length * (len(batch) + 1) <= max_words: elif length * (len(batch) + 1) <= max_words:
batch.append(i) batch.append(i)
else: else:
batches.append(batch) batches.append(batch)
batch = [i] batch = [i]

View File

@ -241,7 +241,7 @@ def train_while_improving(
score, other_scores = evaluate() score, other_scores = evaluate()
else: else:
score, other_scores = evaluate() score, other_scores = evaluate()
optimizer.last_score = score optimizer.last_score = score #type:ignore[attr-defined]
results.append((score, step)) results.append((score, step))
is_best_checkpoint = score == max(results)[0] is_best_checkpoint = score == max(results)[0]
else: else: