mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
Replace "kb_ids" by a constant
This commit is contained in:
parent
ac5b1fd264
commit
b828954fb6
|
@ -26,6 +26,8 @@ from ..scorer import Scorer
|
||||||
|
|
||||||
ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
|
ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
|
||||||
|
|
||||||
|
KNOWLEDGE_BASE_IDS = "kb_ids"
|
||||||
|
|
||||||
# See #9050
|
# See #9050
|
||||||
BACKWARD_OVERWRITE = True
|
BACKWARD_OVERWRITE = True
|
||||||
|
|
||||||
|
@ -426,7 +428,7 @@ class EntityLinker(TrainablePipe):
|
||||||
docs_ents: List[Ragged] = []
|
docs_ents: List[Ragged] = []
|
||||||
docs_scores: List[Ragged] = []
|
docs_scores: List[Ragged] = []
|
||||||
if not docs:
|
if not docs:
|
||||||
return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
|
return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
@ -532,7 +534,7 @@ class EntityLinker(TrainablePipe):
|
||||||
method="predict", msg="result variables not of equal length"
|
method="predict", msg="result variables not of equal length"
|
||||||
)
|
)
|
||||||
raise RuntimeError(err)
|
raise RuntimeError(err)
|
||||||
return {"kb_ids": final_kb_ids, "ents": docs_ents, "scores": docs_scores}
|
return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
|
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
|
||||||
"""Modify a batch of documents, using pre-computed scores.
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
|
@ -543,7 +545,7 @@ class EntityLinker(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entitylinker#set_annotations
|
DOCS: https://spacy.io/api/entitylinker#set_annotations
|
||||||
"""
|
"""
|
||||||
kb_ids = cast(List[str], activations["kb_ids"])
|
kb_ids = cast(List[str], activations[KNOWLEDGE_BASE_IDS])
|
||||||
count_ents = len([ent for doc in docs for ent in doc.ents])
|
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||||
if count_ents != len(kb_ids):
|
if count_ents != len(kb_ids):
|
||||||
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
||||||
|
@ -553,7 +555,7 @@ class EntityLinker(TrainablePipe):
|
||||||
if self.save_activations:
|
if self.save_activations:
|
||||||
doc.activations[self.name] = {}
|
doc.activations[self.name] = {}
|
||||||
for act_name, acts in activations.items():
|
for act_name, acts in activations.items():
|
||||||
if act_name != "kb_ids":
|
if act_name != KNOWLEDGE_BASE_IDS:
|
||||||
# We only copy activations that are Ragged.
|
# We only copy activations that are Ragged.
|
||||||
doc.activations[self.name][act_name] = cast(Ragged, acts[j])
|
doc.activations[self.name][act_name] = cast(Ragged, acts[j])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user