Modify candidate retrieval interface to accept docs instead of individual spans.

This commit is contained in:
Raphael Mitsch 2022-12-14 11:51:37 +01:00
parent df6e4ab055
commit 53a24abd8b
5 changed files with 45 additions and 34 deletions

View File

@ -1,11 +1,11 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from pathlib import Path from pathlib import Path
from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type, Optional
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from .candidate import Candidate from .candidate import Candidate
from ..tokens import Span from ..tokens import Span, Doc
from ..util import SimpleFrozenList from ..util import SimpleFrozenList
from ..errors import Errors from ..errors import Errors
@ -32,24 +32,25 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length self.entity_vector_length = entity_vector_length
self.mem = Pool() self.mem = Pool()
def get_candidates_all(self, mentions: Iterator[Iterable[Span]]) -> Iterator[Iterable[Iterable[Candidate]]]: def get_candidates_all(self, docs: Iterator[Doc]) -> Iterator[Iterable[Iterable[Candidate]]]:
""" """
Return candidate entities for specified mentions. Each candidate defines the entity, the original alias, Return candidate entities for mentions stored in the `ents` attribute of the passed docs. Each candidate defines the
and the prior probability of that alias resolving to that entity. entity, the original alias, and the prior probability of that alias resolving to that entity.
If no candidate is found for a given mention, an empty list is returned. If no candidate is found for a given mention, an empty list is returned.
mentions (Generator[Iterable[Span]]): Mentions per documents for which to get candidates. docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`).
RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
""" """
for doc in docs:
yield [self.get_candidates(ent_span, doc) for ent_span in doc.ents]
for doc_mentions in mentions: def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
yield [self.get_candidates(span) for span in doc_mentions]
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
""" """
Return candidate entities for specified text. Each candidate defines the entity, the original alias, Return candidate entities for specified text. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity. and the prior probability of that alias resolving to that entity.
If no candidate is found for a given text, an empty list is returned. If no candidate is found for a given text, an empty list is returned.
Note that doc is not utilized for further context in this implementation.
mention (Span): Mention for which to get candidates. mention (Span): Mention for which to get candidates.
doc (Optional[Doc]): Doc to use for context.
RETURNS (Iterable[Candidate]): Identified candidates. RETURNS (Iterable[Candidate]): Identified candidates.
""" """
raise NotImplementedError( raise NotImplementedError(

View File

@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from typing import Iterable, Callable, Dict, Any, Union from typing import Iterable, Callable, Dict, Any, Union, Optional
import srsly import srsly
from preshed.maps cimport PreshMap from preshed.maps cimport PreshMap
@ -11,7 +11,7 @@ from libcpp.vector cimport vector
from pathlib import Path from pathlib import Path
import warnings import warnings
from ..tokens import Span from ..tokens import Span, Doc
from ..typedefs cimport hash_t from ..typedefs cimport hash_t
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from .. import util from .. import util
@ -231,7 +231,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
alias_entry.probs = probs alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry self._aliases_table[alias_index] = alias_entry
def get_candidates(self, mention: Span) -> Iterable[Candidate]: def get_candidates(self, mention: Span, doc: Optional[Doc] = None) -> Iterable[Candidate]:
return self.get_alias_candidates(mention.text) # type: ignore return self.get_alias_candidates(mention.text) # type: ignore
def get_alias_candidates(self, str alias) -> Iterable[Candidate]: def get_alias_candidates(self, str alias) -> Iterable[Candidate]:

View File

@ -118,15 +118,15 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
def get_candidates_all( def get_candidates_all(
kb: KnowledgeBase, mentions: Iterator[Iterable[Span]] kb: KnowledgeBase, docs: Iterator[Doc]
) -> Iterator[Iterable[Iterable[Candidate]]]: ) -> Iterator[Iterable[Iterable[Candidate]]]:
""" """
Return candidate entities for the given mentions by fetching appropriate entries from the index. Return candidate entities for the given mentions by fetching appropriate entries from the index.
kb (KnowledgeBase): Knowledge base to query. kb (KnowledgeBase): Knowledge base to query.
mention (Iterator[Iterable[Span]]): Entity mentions per document for which to identify candidates. docs (Iterator[Doc]): Doc instances with mentions (stored in `.ents`).
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
""" """
return kb.get_candidates_all(mentions) return kb.get_candidates_all(docs)
@registry.misc("spacy.CandidateGenerator.v1") @registry.misc("spacy.CandidateGenerator.v1")

View File

@ -87,8 +87,7 @@ def make_entity_linker(
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[ get_candidates_all: Callable[
[KnowledgeBase, Iterator[Iterable[Span]]], [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
Iterator[Iterable[Iterable[Candidate]]],
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool, overwrite: bool,
@ -107,11 +106,11 @@ def make_entity_linker(
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list of
produces a list of candidates, given a certain knowledge base and a textual mention. candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]): get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]): Function
Function that produces a list of candidates per document, given a certain knowledge base and several textual that produces a list of candidates per document, given a certain knowledge base and several textual documents
documents with textual mentions. with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
@ -188,8 +187,7 @@ class EntityLinker(TrainablePipe):
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_all: Callable[ get_candidates_all: Callable[
[KnowledgeBase, Iterator[Iterable[Span]]], [KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]
Iterator[Iterable[Iterable[Candidate]]],
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE, overwrite: bool = BACKWARD_OVERWRITE,
@ -209,9 +207,9 @@ class EntityLinker(TrainablePipe):
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function producing a list
produces a list of candidates, given a certain knowledge base and a textual mention. of candidates, given a certain knowledge base and a textual mention.
get_candidates_all (Callable[[KnowledgeBase, Iterator[Iterable[Span]]], Iterator[Iterable[Iterable[Candidate]]]]): get_candidates_all (Callable[[KnowledgeBase, Iterator[Doc]], Iterator[Iterable[Iterable[Candidate]]]]):
Function that produces a list of candidates per document, given a certain knowledge base and several textual Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions. documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
@ -330,7 +328,6 @@ class EntityLinker(TrainablePipe):
If one isn't present, then the update step needs to be skipped. If one isn't present, then the update step needs to be skipped.
""" """
for eg in examples: for eg in examples:
for ent in eg.predicted.ents: for ent in eg.predicted.ents:
candidates = list(self.get_candidates(self.kb, ent)) candidates = list(self.get_candidates(self.kb, ent))
@ -471,10 +468,22 @@ class EntityLinker(TrainablePipe):
# Call candidate generator. # Call candidate generator.
if self.candidates_doc_mode: if self.candidates_doc_mode:
def _adjust_ents_in_doc(doc: Doc, valid_ent_idx: Iterable[int]) -> Doc:
"""
Generates a copy of the doc object that keeps only those ents for which candidates are to be retrieved.
doc (Doc): Doc object to adjust.
valid_ent_idx (Iterable[int]): Indices of entities to keep.
RETURNS (Doc): Doc instance with only valid entities (i.e. those to retrieve candidates for).
"""
_doc = doc.copy()
_doc.ents = [doc.ents[i] for i in valid_ent_idx]
return _doc
all_ent_cands = self.get_candidates_all( all_ent_cands = self.get_candidates_all(
self.kb, self.kb,
( (
[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] _adjust_ents_in_doc(doc, next(valid_ent_idx_per_doc))
for doc in docs for doc in docs
if len(doc) and len(doc.ents) if len(doc) and len(doc.ents)
), ),
@ -564,6 +573,7 @@ class EntityLinker(TrainablePipe):
method="predict", msg="result variables not of equal length" method="predict", msg="result variables not of equal length"
) )
raise RuntimeError(err) raise RuntimeError(err)
return final_kb_ids return final_kb_ids
def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None: def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:

View File

@ -503,9 +503,9 @@ def test_el_pipe_configuration(nlp):
def get_lowercased_candidates(kb, span): def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower()) return kb.get_alias_candidates(span.text.lower())
def get_lowercased_candidates_all(kb, spans_per_doc): def get_lowercased_candidates_all(kb, docs):
for doc_spans in spans_per_doc: for _doc in docs:
yield [get_lowercased_candidates(kb, span) for span in doc_spans] yield [get_lowercased_candidates(kb, ent_span) for ent_span in _doc.ents]
@registry.misc("spacy.LowercaseCandidateGenerator.v1") @registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[ def create_candidates() -> Callable[