diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index 3ee434ab5..fa537edc9 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -18,6 +18,8 @@ cdef class KnowledgeBase: DOCS: https://spacy.io/api/kb """ + _KBType = TypeVar("_KBType", bound=KnowledgeBase) + def __init__(self, vocab: Vocab, entity_vector_length: int): """Create a KnowledgeBase.""" # Make sure abstract KB is not instantiated. @@ -107,16 +109,15 @@ cdef class KnowledgeBase: Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) ) - KBType = TypeVar("KBType", bound=KnowledgeBase) @classmethod def generate_from_disk( - cls: Type[KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() - ) -> KBType: + cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> _KBType: """ - Factory method for generating KnowledgeBase instance from file. + Factory method for generating KnowledgeBase subclass instance from file. path (Union[str, Path]): Target file path. exclude (Iterable[str]): List of components to exclude. - return (KBType): Instance of KnowledgeBase generated from file. + return (_KBType): Instance of KnowledgeBase subclass generated from file. """ raise NotImplementedError( Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__) diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 2f8ab20a5..fd84981b6 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -99,6 +99,14 @@ def empty_kb( return empty_kb_factory +@registry.misc("spacy.EmptyKB.v2") +def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]: + def empty_kb_factory(vocab: Vocab, entity_vector_length: int): + return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length) + + return empty_kb_factory + + @registry.misc("spacy.CandidateGenerator.v1") def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: return get_candidates diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index b2d8e7a13..55a04e7ca 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -18,7 +18,6 @@ from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate from ..kb import KnowledgeBase, Candidate -from ..ml import empty_kb from ..tokens import Doc, Span from .pipe import deserialize_config from .legacy.entity_linker import EntityLinker_v1 @@ -64,6 +63,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"}, + "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, "overwrite": True, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "use_gold_ents": True, @@ -91,6 +91,7 @@ def make_entity_linker( [KnowledgeBase, Generator[Iterable[Span], None, None]], Iterator[Iterable[Iterable[Candidate]]], ], + generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, scorer: Optional[Callable], use_gold_ents: bool, @@ -115,6 +116,7 @@ def make_entity_linker( Iterator[Iterable[Iterable[Candidate]]] ]): Function that produces a list of candidates per document, given a certain knowledge base and several textual documents with textual mentions. + generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. @@ -151,6 +153,7 @@ def make_entity_linker( entity_vector_length=entity_vector_length, get_candidates=get_candidates, get_candidates_all=get_candidates_all, + generate_empty_kb=generate_empty_kb, overwrite=overwrite, scorer=scorer, use_gold_ents=use_gold_ents, @@ -192,6 +195,7 @@ class EntityLinker(TrainablePipe): [KnowledgeBase, Generator[Iterable[Span], None, None]], Iterator[Iterable[Iterable[Candidate]]], ], + generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = BACKWARD_OVERWRITE, scorer: Optional[Callable] = entity_linker_score, use_gold_ents: bool, @@ -215,14 +219,15 @@ class EntityLinker(TrainablePipe): Callable[ [KnowledgeBase, Generator[Iterable[Span], None, None]], Iterator[Iterable[Iterable[Candidate]]] - ]): Function that produces a list of candidates per document, given a certain knowledge base and several textual - documents with textual mentions. + ]): Function that produces a list of candidates per document, given a certain knowledge base and several + textual documents with textual mentions. + generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. - candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a generator - yielding entities per document (candidate generator callable is called only once in this case). If False, - the candidate generator is called once per entity. + candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a + generator yielding entities per document (candidate generator callable is called only once in this case). If + False, the candidate generator is called once per entity. threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, prediction is discarded. If None, predictions are not filtered by any threshold. DOCS: https://spacy.io/api/entitylinker#init @@ -241,16 +246,16 @@ class EntityLinker(TrainablePipe): self.model = model self.name = name self.labels_discard = list(labels_discard) + # how many neighbour sentences to take into account self.n_sents = n_sents self.incl_prior = incl_prior self.incl_context = incl_context self.get_candidates = get_candidates self.get_candidates_all = get_candidates_all + self.generate_empty_kb = generate_empty_kb self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.distance = CosineDistance(normalize=False) - # how many neighbour sentences to take into account - # create an empty KB by default - self.kb = empty_kb(entity_vector_length)(self.vocab) + self.kb = generate_empty_kb(self.vocab, entity_vector_length) self.scorer = scorer self.use_gold_ents = use_gold_ents self.candidates_doc_mode = candidates_doc_mode