From 44e14ccae87d4077cfc3b730e76ab32bbb15cafb Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 14 Oct 2020 15:11:34 +0200 Subject: [PATCH] one more losses fix --- spacy/pipeline/tagger.pyx | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 3be93c32c..16633a7b8 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -227,10 +227,13 @@ class Tagger(TrainablePipe): DOCS: https://nightly.spacy.io/api/tagger#rehearse """ + if losses is None: + losses = {} + losses.setdefault(self.name, 0.0) validate_examples(examples, "Tagger.rehearse") docs = [eg.predicted for eg in examples] if self._rehearsal_model is None: - return + return losses if not any(len(doc) for doc in docs): # Handle cases where there are no tokens in any docs. return losses @@ -240,9 +243,7 @@ class Tagger(TrainablePipe): gradient = guesses - target backprop(gradient) self.finish_update(sgd) - if losses is not None: - losses.setdefault(self.name, 0.0) - losses[self.name] += (gradient**2).sum() + losses[self.name] += (gradient**2).sum() return losses def get_loss(self, examples, scores):