restore entities of sample in initialization

This commit is contained in:
svlandeg 2024-03-27 16:37:45 +01:00
parent 42188b30c5
commit d41d875186

View File

@ -284,9 +284,11 @@ class EntityLinker(TrainablePipe):
nO = self.kb.entity_vector_length nO = self.kb.entity_vector_length
doc_sample = [] doc_sample = []
vector_sample = [] vector_sample = []
orig_ents = []
for eg in islice(get_examples(), 10): for eg in islice(get_examples(), 10):
doc = eg.x doc = eg.x
if self.use_gold_ents: if self.use_gold_ents:
orig_ents.append(doc.ents)
ents, _ = eg.get_aligned_ents_and_ner() ents, _ = eg.get_aligned_ents_and_ner()
doc.ents = ents doc.ents = ents
doc_sample.append(doc) doc_sample.append(doc)
@ -313,6 +315,10 @@ class EntityLinker(TrainablePipe):
if not has_annotations: if not has_annotations:
# Clean up dummy annotation # Clean up dummy annotation
doc.ents = [] doc.ents = []
if self.use_gold_ents:
assert len(doc_sample) == len(orig_ents)
for doc, orig_ent in zip(doc_sample, orig_ents):
doc.ents = orig_ent
def batch_has_learnable_example(self, examples): def batch_has_learnable_example(self, examples):
"""Check if a batch contains a learnable example. """Check if a batch contains a learnable example.