diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 650a01949..1364c1ae1 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -288,11 +288,14 @@ class TextCategorizer(TrainablePipe): set_dropout_rate(self.model, drop) scores, bp_scores = self.model.begin_update(docs) target, _ = self._rehearsal_model.begin_update(docs) - gradient = scores - target - bp_scores(gradient) + teacher_loss, teacher_scores = self.get_rehearse_loss(target, scores) + truth_loss, truth_scores = self.get_loss(examples, scores) + d_scores = (teacher_scores + truth_scores) / 2 + loss = teacher_loss + truth_loss + bp_scores(d_scores) if sgd is not None: self.finish_update(sgd) - losses[self.name] += (gradient**2).sum() + losses[self.name] += float(loss) return losses def _examples_to_truth(