fix formatting

This commit is contained in:
svlandeg 2024-03-27 18:45:34 +01:00
parent ff88ab341a
commit a1cde9d021
2 changed files with 15 additions and 9 deletions

View File

@ -258,14 +258,16 @@ class EntityLinker(TrainablePipe):
self.scorer = _score_augmented
def _augment_examples(self, examples: Iterable[Example]) -> Iterable[Example]:
"""If use_gold_ents is true, set the gold entities to eg.predicted.
"""
"""If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted."""
if not self.use_gold_ents:
return examples
new_examples = []
for eg in examples:
if self.use_gold_ents:
ents, _ = eg.get_aligned_ents_and_ner()
eg.predicted.ents = ents
new_examples.append(eg)
ents, _ = eg.get_aligned_ents_and_ner()
new_eg = eg.copy()
new_eg.predicted.ents = ents
new_examples.append(new_eg)
return new_examples
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
@ -399,7 +401,7 @@ class EntityLinker(TrainablePipe):
return losses
def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
""" Here, we assume that get_loss is called with augmented examples if need be"""
"""Here, we assume that get_loss is called with augmented examples if need be"""
validate_examples(examples, "EntityLinker.get_loss")
entity_encodings = []
eidx = 0 # indices in gold entities to keep

View File

@ -744,7 +744,9 @@ def test_overfitting_IO_gold_entities():
return mykb
# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": True})
entity_linker = nlp.add_pipe(
"entity_linker", last=True, config={"use_gold_ents": True}
)
assert isinstance(entity_linker, EntityLinker)
entity_linker.set_kb(create_kb)
assert "Q2146908" in entity_linker.vocab.strings
@ -849,7 +851,9 @@ def test_overfitting_IO_with_ner():
# Create the NER and EL components and add them to the pipeline
ner = nlp.add_pipe("ner", first=True)
entity_linker = nlp.add_pipe("entity_linker", last=True, config={"use_gold_ents": False})
entity_linker = nlp.add_pipe(
"entity_linker", last=True, config={"use_gold_ents": False}
)
entity_linker.set_kb(create_kb)
train_examples = []