only restore entities after loss calculation

This commit is contained in:
svlandeg 2024-03-27 16:36:55 +01:00
parent 2b0a3e2e40
commit 42188b30c5

View File

@ -375,10 +375,6 @@ class EntityLinker(TrainablePipe):
sentence_encodings, bp_context = self.model.begin_update(docs)
# now restore the ents
for doc, old in zip(docs, old_ents):
doc.ents = old
loss, d_scores = self.get_loss(
sentence_encodings=sentence_encodings, examples=examples
)
@ -386,6 +382,12 @@ class EntityLinker(TrainablePipe):
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
# now restore the ents
assert len(docs) == len(old_ents)
for doc, old in zip(docs, old_ents):
doc.ents = old
return losses
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):