diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 2a5f3962d..b371ca9a4 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate import warnings from ..kb import KnowledgeBase, Candidate +from ..ml import empty_kb from ..tokens import Doc from .pipe import Pipe, deserialize_config from ..language import Language @@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], assigns=["token.ent_kb_id"], default_config={ - "kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64}, "model": DEFAULT_NEL_MODEL, "labels_discard": [], "incl_prior": True, "incl_context": True, + "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, }, default_score_weights={ @@ -58,11 +59,11 @@ def make_entity_linker( nlp: Language, name: str, model: Model, - kb_loader: Callable[[Vocab], KnowledgeBase], *, labels_discard: Iterable[str], incl_prior: bool, incl_context: bool, + entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]], ): """Construct an EntityLinker component. @@ -70,19 +71,21 @@ def make_entity_linker( model (Model[List[Doc], Floats2d]): A model that learns document vector representations. Given a batch of Doc objects, it should return a single array, with one row per item in the batch. - kb (KnowledgeBase): The knowledge-base to link entities to. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. + entity_vector_length (int): Size of encoding vectors in the KB. + get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that + produces a list of candidates, given a certain knowledge base and a textual mention. """ return EntityLinker( nlp.vocab, model, name, - kb_loader=kb_loader, labels_discard=labels_discard, incl_prior=incl_prior, incl_context=incl_context, + entity_vector_length=entity_vector_length, get_candidates=get_candidates, ) @@ -101,10 +104,10 @@ class EntityLinker(Pipe): model: Model, name: str = "entity_linker", *, - kb_loader: Callable[[Vocab], KnowledgeBase], labels_discard: Iterable[str], incl_prior: bool, incl_context: bool, + entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]], ) -> None: """Initialize an entity linker. @@ -113,10 +116,12 @@ class EntityLinker(Pipe): model (thinc.api.Model): The Thinc Model powering the pipeline component. name (str): The component instance name, used to add entries to the losses during training. - kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. + entity_vector_length (int): Size of encoding vectors in the KB. + get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that + produces a list of candidates, given a certain knowledge base and a textual mention. DOCS: https://nightly.spacy.io/api/entitylinker#init """ @@ -127,15 +132,17 @@ class EntityLinker(Pipe): "labels_discard": list(labels_discard), "incl_prior": incl_prior, "incl_context": incl_context, + "entity_vector_length": entity_vector_length, } - self.kb = kb_loader(self.vocab) self.get_candidates = get_candidates self.cfg = dict(cfg) self.distance = CosineDistance(normalize=False) # how many neightbour sentences to take into account self.n_sents = cfg.get("n_sents", 0) + # create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'. + self.kb = empty_kb(entity_vector_length)(self.vocab) - def _require_kb(self) -> None: + def validate_kb(self) -> None: # Raise an error if the knowledge base is not initialized. if len(self.kb) == 0: raise ValueError(Errors.E139.format(name=self.name)) @@ -145,6 +152,7 @@ class EntityLinker(Pipe): get_examples: Callable[[], Iterable[Example]], *, nlp: Optional[Language] = None, + kb_loader: Callable[[Vocab], KnowledgeBase] = None, ): """Initialize the pipe for training, using a representative set of data examples. @@ -152,11 +160,17 @@ class EntityLinker(Pipe): get_examples (Callable[[], Iterable[Example]]): Function that returns a representative sample of gold-standard Example objects. nlp (Language): The current nlp object the component is part of. + kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance. + Note that providing this argument, will overwrite all data accumulated in the current KB. + Use this only when loading a KB as-such from file. DOCS: https://nightly.spacy.io/api/entitylinker#initialize """ self._ensure_examples(get_examples) - self._require_kb() + if kb_loader is not None: + self.kb = kb_loader(self.vocab) + self.cfg["entity_vector_length"] = self.kb.entity_vector_length + self.validate_kb() nO = self.kb.entity_vector_length doc_sample = [] vector_sample = [] @@ -192,7 +206,7 @@ class EntityLinker(Pipe): DOCS: https://nightly.spacy.io/api/entitylinker#update """ - self._require_kb() + self.validate_kb() if losses is None: losses = {} losses.setdefault(self.name, 0.0) @@ -303,7 +317,7 @@ class EntityLinker(Pipe): DOCS: https://nightly.spacy.io/api/entitylinker#predict """ - self._require_kb() + self.validate_kb() entity_count = 0 final_kb_ids = [] if not docs: