always return losses

This commit is contained in:
svlandeg 2020-10-14 15:00:49 +02:00
parent 1f49300862
commit 0aa8851878
2 changed files with 4 additions and 3 deletions

View File

@ -195,7 +195,7 @@ class Tagger(TrainablePipe):
validate_examples(examples, "Tagger.update") validate_examples(examples, "Tagger.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples]) tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
for sc in tag_scores: for sc in tag_scores:
@ -233,7 +233,7 @@ class Tagger(TrainablePipe):
return return
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
guesses, backprop = self.model.begin_update(docs) guesses, backprop = self.model.begin_update(docs)
target = self._rehearsal_model(examples) target = self._rehearsal_model(examples)
@ -243,6 +243,7 @@ class Tagger(TrainablePipe):
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum() losses[self.name] += (gradient**2).sum()
return losses
def get_loss(self, examples, scores): def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and """Find the loss and gradient of loss for the batch of documents and

View File

@ -116,7 +116,7 @@ cdef class TrainablePipe(Pipe):
validate_examples(examples, "TrainablePipe.update") validate_examples(examples, "TrainablePipe.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples]) scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
loss, d_scores = self.get_loss(examples, scores) loss, d_scores = self.get_loss(examples, scores)