From b828954fb64abd16c1d7462a1ef11a2d3f7aaa16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 8 Sep 2022 10:31:55 +0200 Subject: [PATCH] Replace "kb_ids" by a constant --- spacy/pipeline/entity_linker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index e69fa618b..ac05cb840 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -26,6 +26,8 @@ from ..scorer import Scorer ActivationsT = Dict[str, Union[List[Ragged], List[str]]] +KNOWLEDGE_BASE_IDS = "kb_ids" + # See #9050 BACKWARD_OVERWRITE = True @@ -426,7 +428,7 @@ class EntityLinker(TrainablePipe): docs_ents: List[Ragged] = [] docs_scores: List[Ragged] = [] if not docs: - return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores} + return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores} if isinstance(docs, Doc): docs = [docs] for doc in docs: @@ -532,7 +534,7 @@ class EntityLinker(TrainablePipe): method="predict", msg="result variables not of equal length" ) raise RuntimeError(err) - return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores} + return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores} def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None: """Modify a batch of documents, using pre-computed scores. @@ -543,7 +545,7 @@ class EntityLinker(TrainablePipe): DOCS: https://spacy.io/api/entitylinker#set_annotations """ - kb_ids = cast(List[str], activations["kb_ids"]) + kb_ids = cast(List[str], activations[KNOWLEDGE_BASE_IDS]) count_ents = len([ent for doc in docs for ent in doc.ents]) if count_ents != len(kb_ids): raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids))) @@ -553,7 +555,7 @@ class EntityLinker(TrainablePipe): if self.save_activations: doc.activations[self.name] = {} for act_name, acts in activations.items(): - if act_name != "kb_ids": + if act_name != KNOWLEDGE_BASE_IDS: # We only copy activations that are Ragged. doc.activations[self.name][act_name] = cast(Ragged, acts[j])