From a1cde9d021a94e1005607af9f11f6032fcf481ea Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 27 Mar 2024 18:45:34 +0100 Subject: [PATCH] fix formatting --- spacy/pipeline/entity_linker.py | 16 +++++++++------- spacy/tests/pipeline/test_entity_linker.py | 8 ++++++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index d6688002b..790baf766 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -258,14 +258,16 @@ class EntityLinker(TrainablePipe): self.scorer = _score_augmented def _augment_examples(self, examples: Iterable[Example]) -> Iterable[Example]: - """If use_gold_ents is true, set the gold entities to eg.predicted. - """ + """If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted.""" + if not self.use_gold_ents: + return examples + new_examples = [] for eg in examples: - if self.use_gold_ents: - ents, _ = eg.get_aligned_ents_and_ner() - eg.predicted.ents = ents - new_examples.append(eg) + ents, _ = eg.get_aligned_ents_and_ner() + new_eg = eg.copy() + new_eg.predicted.ents = ents + new_examples.append(new_eg) return new_examples def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): @@ -399,7 +401,7 @@ class EntityLinker(TrainablePipe): return losses def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d): - """ Here, we assume that get_loss is called with augmented examples if need be""" + """Here, we assume that get_loss is called with augmented examples if need be""" validate_examples(examples, "EntityLinker.get_loss") entity_encodings = [] eidx = 0 # indices in gold entities to keep diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 8450c5bf4..5e50a4d28 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -744,7 +744,9 @@ def test_overfitting_IO_gold_entities(): return mykb # Create the Entity Linker component and add it to the pipeline - entity_linker = nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": True}) + entity_linker = nlp.add_pipe( + "entity_linker", last=True, config={"use_gold_ents": True} + ) assert isinstance(entity_linker, EntityLinker) entity_linker.set_kb(create_kb) assert "Q2146908" in entity_linker.vocab.strings @@ -849,7 +851,9 @@ def test_overfitting_IO_with_ner(): # Create the NER and EL components and add them to the pipeline ner = nlp.add_pipe("ner", first=True) - entity_linker = nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": False}) + entity_linker = nlp.add_pipe( + "entity_linker", last=True, config={"use_gold_ents": False} + ) entity_linker.set_kb(create_kb) train_examples = []