ensure the loss value is cast as float (#6928)

This commit is contained in:
Sofie Van Landeghem 2021-02-07 00:51:56 +01:00 committed by GitHub
parent a7977b5143
commit a323ef90df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -273,7 +273,7 @@ class EntityLinker(TrainablePipe):
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
loss = self.distance.get_loss(sentence_encodings, entity_encodings)
loss = loss / len(entity_encodings)
return loss, gradients
return float(loss), gradients
def predict(self, docs: Iterable[Doc]) -> List[str]:
"""Apply the pipeline's model to a batch of docs, without modifying them.

View File

@ -197,7 +197,7 @@ class ClozeMultitask(TrainablePipe):
target = vectors[ids]
gradient = self.distance.get_grad(prediction, target)
loss = self.distance.get_loss(prediction, target)
return loss, gradient
return float(loss), gradient
def update(self, examples, *, drop=0., sgd=None, losses=None):
pass