Modify candidate retrieval interface to accept docs instead of individual spans.

This commit is contained in:
Raphael Mitsch 2022-12-14 11:51:37 +01:00
parent df6e4ab055
commit 53a24abd8b
5 changed files with 45 additions and 34 deletions

View File

@ -1,11 +1,11 @@
# cython: infer_types=True, profile=True
from pathlib import Path
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional
from cymem.cymem cimport Pool
from .candidate import Candidate
from ..tokens import Span
from ..tokens import Span, Doc
from ..util import SimpleFrozenList
from ..errors import Errors
@ -32,24 +32,25 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length
self.mem = Pool()
def get_candidates_all(self, mentions: Iterator[Iterable[Span]]) -> Iterator[Iterable[Iterable[Candidate]]]:
def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]:
"""
Return candidate entities for specified mentions. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the
entity, the original alias, and the prior probability of that alias resolving to that entity.
If no candidate is found for a given mention, an empty list is returned.
mentions (Generator[Iterable[Span]]): Mentions per documents for which to get candidates.
RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`).
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
"""
for doc in docs:
yield [self.get_candidates(ent_span, doc) for ent_span in doc.ents]
for doc_mentions in mentions:
yield [self.get_candidates(span) for span in doc_mentions]
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
"""
Return candidate entities for specified text. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
If the no candidate is found for a given text, an empty list is returned.
Note that doc is not utilized for further context in this implementation.
mention (Span): Mention for which to get candidates.
doc (Optional[Doc]): Doc to use for context.
RETURNS (Iterable[Candidate]): Identified candidates.
"""
raise NotImplementedError(

View File

@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True
from typing import Iterable, Callable, Dict, Any, Union
from typing import Iterable, Callable, Dict, Any, Union, Optional
import srsly
from preshed.maps cimport PreshMap
@ -11,7 +11,7 @@ from libcpp.vector cimport vector
from pathlib import Path
import warnings
from ..tokens import Span
from ..tokens import Span, Doc
from ..typedefs cimport hash_t
from ..errors import Errors, Warnings
from .. import util
@ -231,7 +231,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
return self.get_alias_candidates(mention.text) # type: ignore
def get_alias_candidates(self, str alias) -> Iterable[Candidate]:

View File

@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
def get_candidates_all(
kb: KnowledgeBase, mentions: Iterator[Iterable[Span]]
kb: KnowledgeBase, docs: Iterator[Doc]
) -> Iterator[Iterable[Iterable[Candidate]]]:
"""
Return candidate entities for the given mentions and fetching appropriate entries from the index.
kb (KnowledgeBase): Knowledge base to query.
mention (Iterator[Iterable[Span]]): Entity mentions per document for which to identify candidates.
docs (Iterator[Doc]): Doc instances with mentions (stored in `.ent`).
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
"""
return kb.get_candidates_all(mentions)
return kb.get_candidates_all(docs)
@registry.misc("spacy.CandidateGenerator.v1")

View File

@ -87,8 +87,7 @@ def make_entity_linker(
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[
[KnowledgeBase, Iterator[Iterable[Span]]],
Iterator[Iterable[Iterable[Candidate]]],
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
@ -107,11 +106,11 @@ def make_entity_linker(
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]):
Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function
that produces a list of candidates per document, given a certain knowledge base and several textual documents
with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
@ -188,8 +187,7 @@ class EntityLinker(TrainablePipe):
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[
[KnowledgeBase, Iterator[Iterable[Span]]],
Iterator[Iterable[Iterable[Candidate]]],
[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE,
@ -209,9 +207,9 @@ class EntityLinker(TrainablePipe):
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]):
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
of candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]):
Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
@ -330,7 +328,6 @@ class EntityLinker(TrainablePipe):
If one isn't present, then the update step needs to be skipped.
"""
for eg in examples:
for ent in eg.predicted.ents:
candidates = list(self.get_candidates(self.kb, ent))
@ -471,10 +468,22 @@ class EntityLinker(TrainablePipe):
# Call candidate generator.
if self.candidates_doc_mode:
def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc:
"""
Generates copy of doc object with only those ents that are candidates are to be retrieved for.
doc (Doc): Doc object to adjust.
valid_ent_idx (Iterable[int]): Indices of entities to keep.
RETURN (doc): Doc instance with only valid entities (i.e. those to retrieve candidates for).
"""
_doc = doc.copy()
_doc.ents = [doc.ents[i] for i in valid_ent_idx]
return _doc
all_ent_cands = self.get_candidates_all(
self.kb,
(
[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)]
_adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc))
for doc in docs
if len(doc) and len(doc.ents)
),
@ -564,6 +573,7 @@ class EntityLinker(TrainablePipe):
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
return final_kb_ids
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:

View File

@ -503,9 +503,9 @@ def test_el_pipe_configuration(nlp):
def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower())
def get_lowercased_candidates_all(kb, spans_per_doc):
for doc_spans in spans_per_doc:
yield [get_lowercased_candidates(kb, span) for span in doc_spans]
def get_lowercased_candidates_all(kb, docs):
for _doc in docs:
yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents]
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[