From 42188b30c5dd6cb6ee1dbd909e127fd8f063c4f3 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 27 Mar 2024 16:36:55 +0100 Subject: [PATCH] only restore entities after loss calculation --- spacy/pipeline/entity_linker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index a730ece1b..3f70475fd 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -375,10 +375,6 @@ class EntityLinker(TrainablePipe): sentence_encodings, bp_context = self.model.begin_update(docs) - # now restore the ents - for doc, old in zip(docs, old_ents): - doc.ents = old - loss, d_scores = self.get_loss( sentence_encodings=sentence_encodings, examples=examples ) @@ -386,6 +382,12 @@ class EntityLinker(TrainablePipe): if sgd is not None: self.finish_update(sgd) losses[self.name] += loss + + # now restore the ents + assert len(docs) == len(old_ents) + for doc, old in zip(docs, old_ents): + doc.ents = old + return losses def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):