diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index fa537edc9..12156590d 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,11 +1,11 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type +from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional from cymem.cymem cimport Pool from .candidate import Candidate -from ..tokens import Span +from ..tokens import Span, Doc from ..util import SimpleFrozenList from ..errors import Errors @@ -32,24 +32,25 @@ cdef class KnowledgeBase: self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_all(self, mentions: Iterator[Iterable[Span]]) -> Iterator[Iterable[Iterable[Candidate]]]: + def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]: """ - Return candidate entities for specified mentions. Each candidate defines the entity, the original alias, - and the prior probability of that alias resolving to that entity. + Return candidate entities for mentions stored in the `ents` attribute of the passed docs. Each candidate defines the + entity, the original alias, and the prior probability of that alias resolving to that entity. If no candidate is found for a given mention, an empty list is returned. - mentions (Generator[Iterable[Span]]): Mentions per documents for which to get candidates. - RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document. + docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`). + RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. 
""" + for doc in docs: + yield [self.get_candidates(ent_span, doc) for ent_span in doc.ents] - for doc_mentions in mentions: - yield [self.get_candidates(span) for span in doc_mentions] - - def get_candidates(self, mention: Span) -> Iterable[Candidate]: + def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]: """ Return candidate entities for specified text. Each candidate defines the entity, the original alias, and the prior probability of that alias resolving to that entity. If the no candidate is found for a given text, an empty list is returned. + Note that doc is not utilized for further context in this implementation. mention (Span): Mention for which to get candidates. + doc (Optional[Doc]): Doc to use for context. RETURNS (Iterable[Candidate]): Identified candidates. """ raise NotImplementedError( diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 97ae08e1e..133dc3abb 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -1,5 +1,5 @@ # cython: infer_types=True, profile=True -from typing import Iterable, Callable, Dict, Any, Union +from typing import Iterable, Callable, Dict, Any, Union, Optional import srsly from preshed.maps cimport PreshMap @@ -11,7 +11,7 @@ from libcpp.vector cimport vector from pathlib import Path import warnings -from ..tokens import Span +from ..tokens import Span, Doc from ..typedefs cimport hash_t from ..errors import Errors, Warnings from .. 
import util @@ -231,7 +231,7 @@ cdef class InMemoryLookupKB(KnowledgeBase): alias_entry.probs = probs self._aliases_table[alias_index] = alias_entry - def get_candidates(self, mention: Span) -> Iterable[Candidate]: + def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]: return self.get_alias_candidates(mention.text) # type: ignore def get_alias_candidates(self, str alias) -> Iterable[Candidate]: diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 293c3910a..99cc31125 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: def get_candidates_all( - kb: KnowledgeBase, mentions: Iterator[Iterable[Span]] + kb: KnowledgeBase, docs: Iterator[Doc] ) -> Iterator[Iterable[Iterable[Candidate]]]: """ Return candidate entities for the given mentions and fetching appropriate entries from the index. kb (KnowledgeBase): Knowledge base to query. - mention (Iterator[Iterable[Span]]): Entity mentions per document for which to identify candidates. + docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`). RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. 
""" - return kb.get_candidates_all(mentions) + return kb.get_candidates_all(docs) @registry.misc("spacy.CandidateGenerator.v1") diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 9df8c357c..25fb654ac 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -87,8 +87,7 @@ def make_entity_linker( entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_all: Callable[ - [KnowledgeBase, Iterator[Iterable[Span]]], - Iterator[Iterable[Iterable[Candidate]]], + [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]] ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool, @@ -107,11 +106,11 @@ def make_entity_linker( incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]): - Function that produces a list of candidates per document, given a certain knowledge base and several textual - documents with textual mentions. + get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of + candidates, given a certain knowledge base and a textual mention. + get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function + that produces a list of candidates per document, given a certain knowledge base and several textual documents + with textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. 
scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another @@ -188,8 +187,7 @@ class EntityLinker(TrainablePipe): entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_all: Callable[ - [KnowledgeBase, Iterator[Iterable[Span]]], - Iterator[Iterable[Iterable[Candidate]]], + [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]] ], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], overwrite: bool = BACKWARD_OVERWRITE, @@ -209,9 +207,9 @@ class EntityLinker(TrainablePipe): incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]): + get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list + of candidates, given a certain knowledge base and a textual mention. + get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function that produces a list of candidates per document, given a certain knowledge base and several textual documents with textual mentions. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. @@ -330,7 +328,6 @@ class EntityLinker(TrainablePipe): If one isn't present, then the update step needs to be skipped. """ - for eg in examples: for ent in eg.predicted.ents: candidates = list(self.get_candidates(self.kb, ent)) @@ -471,10 +468,22 @@ class EntityLinker(TrainablePipe): # Call candidate generator. 
if self.candidates_doc_mode: + + def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc: + """ + Generates a copy of the doc object containing only those ents for which candidates are to be retrieved. + doc (Doc): Doc object to adjust. + valid_ent_idx (Iterable[int]): Indices of entities to keep. + RETURNS (Doc): Doc instance with only valid entities (i.e. those to retrieve candidates for). + """ + _doc = doc.copy() + _doc.ents = [doc.ents[i] for i in valid_ent_idx] + return _doc + all_ent_cands = self.get_candidates_all( self.kb, ( - [doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] + _adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc)) for doc in docs if len(doc) and len(doc.ents) ), @@ -564,6 +573,7 @@ class EntityLinker(TrainablePipe): method="predict", msg="result variables not of equal length" ) raise RuntimeError(err) + return final_kb_ids def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None: diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 2f30672c1..19b276045 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -503,9 +503,9 @@ def test_el_pipe_configuration(nlp): def get_lowercased_candidates(kb, span): return kb.get_alias_candidates(span.text.lower()) - def get_lowercased_candidates_all(kb, spans_per_doc): - for doc_spans in spans_per_doc: - yield [get_lowercased_candidates(kb, span) for span in doc_spans] + def get_lowercased_candidates_all(kb, docs): + for _doc in docs: + yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents] @registry.misc("spacy.LowercaseCandidateGenerator.v1") def create_candidates() -> Callable[