one more losses fix

This commit is contained in:
svlandeg 2020-10-14 15:11:34 +02:00
parent 478a14a619
commit 44e14ccae8

View File

@ -227,10 +227,13 @@ class Tagger(TrainablePipe):
DOCS: https://nightly.spacy.io/api/tagger#rehearse DOCS: https://nightly.spacy.io/api/tagger#rehearse
""" """
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Tagger.rehearse") validate_examples(examples, "Tagger.rehearse")
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
if self._rehearsal_model is None: if self._rehearsal_model is None:
return return losses
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return losses return losses
@ -240,9 +243,7 @@ class Tagger(TrainablePipe):
gradient = guesses - target gradient = guesses - target
backprop(gradient) backprop(gradient)
self.finish_update(sgd) self.finish_update(sgd)
if losses is not None: losses[self.name] += (gradient**2).sum()
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum()
return losses return losses
def get_loss(self, examples, scores): def get_loss(self, examples, scores):