mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
ensure the loss value is cast as float (#6928)
This commit is contained in:
parent
a7977b5143
commit
a323ef90df
|
@ -273,7 +273,7 @@ class EntityLinker(TrainablePipe):
|
||||||
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
||||||
loss = self.distance.get_loss(sentence_encodings, entity_encodings)
|
loss = self.distance.get_loss(sentence_encodings, entity_encodings)
|
||||||
loss = loss / len(entity_encodings)
|
loss = loss / len(entity_encodings)
|
||||||
return loss, gradients
|
return float(loss), gradients
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[str]:
|
def predict(self, docs: Iterable[Doc]) -> List[str]:
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
|
|
@ -197,7 +197,7 @@ class ClozeMultitask(TrainablePipe):
|
||||||
target = vectors[ids]
|
target = vectors[ids]
|
||||||
gradient = self.distance.get_grad(prediction, target)
|
gradient = self.distance.get_grad(prediction, target)
|
||||||
loss = self.distance.get_loss(prediction, target)
|
loss = self.distance.get_loss(prediction, target)
|
||||||
return loss, gradient
|
return float(loss), gradient
|
||||||
|
|
||||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue
Block a user