mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
ensure the loss value is cast as float (#6928)
This commit is contained in:
parent
a7977b5143
commit
a323ef90df
|
@ -273,7 +273,7 @@ class EntityLinker(TrainablePipe):
|
|||
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
||||
loss = self.distance.get_loss(sentence_encodings, entity_encodings)
|
||||
loss = loss / len(entity_encodings)
|
||||
return loss, gradients
|
||||
return float(loss), gradients
|
||||
|
||||
def predict(self, docs: Iterable[Doc]) -> List[str]:
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
|
|
|
@ -197,7 +197,7 @@ class ClozeMultitask(TrainablePipe):
|
|||
target = vectors[ids]
|
||||
gradient = self.distance.get_grad(prediction, target)
|
||||
loss = self.distance.get_loss(prediction, target)
|
||||
return loss, gradient
|
||||
return float(loss), gradient
|
||||
|
||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue
Block a user