From 75aee55bc3ad81ea3fde052ef7706a3f25f8e8a5 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 28 Nov 2022 17:29:35 +0100 Subject: [PATCH] Start refactoring of Candidate classes. --- spacy/kb/candidate.pxd | 12 --- spacy/kb/candidate.py | 127 ++++++++++++++++++++++++++++++++ spacy/kb/candidate.pyx | 76 ------------------- spacy/pipeline/entity_linker.py | 9 ++- 4 files changed, 133 insertions(+), 91 deletions(-) delete mode 100644 spacy/kb/candidate.pxd create mode 100644 spacy/kb/candidate.py delete mode 100644 spacy/kb/candidate.pyx diff --git a/spacy/kb/candidate.pxd b/spacy/kb/candidate.pxd deleted file mode 100644 index 942ce9dd0..000000000 --- a/spacy/kb/candidate.pxd +++ /dev/null @@ -1,12 +0,0 @@ -from .kb cimport KnowledgeBase -from libcpp.vector cimport vector -from ..typedefs cimport hash_t - -# Object used by the Entity Linker that summarizes one entity-alias candidate combination. -cdef class Candidate: - cdef readonly KnowledgeBase kb - cdef hash_t entity_hash - cdef float entity_freq - cdef vector[float] entity_vector - cdef hash_t alias_hash - cdef float prior_prob diff --git a/spacy/kb/candidate.py b/spacy/kb/candidate.py new file mode 100644 index 000000000..420ee9c26 --- /dev/null +++ b/spacy/kb/candidate.py @@ -0,0 +1,127 @@ +import abc +from typing import List, Union, Optional + +from spacy import Errors +from ..tokens import Span + + +class Candidate(abc.ABC): + """A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved + to a specific `entity_id` from a Knowledge Base. This will be used as input for the entity_id linking + algorithm which will disambiguate the various candidates to the correct one. + Each candidate (alias, entity_id) pair is assigned a certain prior probability. + + DOCS: https://spacy.io/api/kb/#candidate-init + """ + + def __init__( + self, mention: str, entity_id: Union[int, str], entity_vector: List[float] + ): + """Create new instance of `Candidate`. Note: has to be a sub-class, otherwise error will be raised. + mention (str): Mention text for this candidate. + entity_id (Union[int, str]): Unique ID of entity_id. + """ + self.mention = mention + self.entity = entity_id + self.entity_vector = entity_vector + + @property + def entity_id(self) -> Union[int, str]: + """RETURNS (Union[int, str]): Entity ID.""" + return self.entity + + def entity_(self) -> Union[int, str]: + """RETURNS (Union[int, str]): Entity ID (for backwards compatibility).""" + return self.entity + + @property + def mention(self) -> str: + """RETURNS (str): Mention.""" + return self.mention + + @property + def entity_vector(self) -> List[float]: + """RETURNS (List[float]): Entity vector.""" + return self.entity_vector + + +class InMemoryLookupKBCandidate(Candidate): + """`Candidate` for InMemoryLookupKBCandidate.""" + + # todo how to resolve circular import issue? -> replace with callable for hash? + def __init__( + self, + kb: KnowledgeBase, + entity_hash, + entity_freq, + entity_vector, + alias_hash, + prior_prob, + ): + """ + prior_prob (float): Prior probability of entity_id for this mention - i.e. the probability that, independent of the + context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In cases in + which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus doesn't) + it might be better to eschew this information and always supply the same value. + """ + self.kb = kb + self.entity_hash = entity_hash + self.entity_freq = entity_freq + self.entity_vector = entity_vector + self.alias_hash = alias_hash + self.prior_prob = prior_prob + + @property + def entity(self) -> int: + """RETURNS (uint64): hash of the entity_id's KB ID/name""" + return self.entity_hash + + @property + def entity_(self) -> str: + """RETURNS (str): ID/name of this entity_id in the KB""" + return self.kb.vocab.strings[self.entity_hash] + + @property + def alias(self) -> int: + """RETURNS (uint64): hash of the alias""" + return self.alias_hash + + @property + def alias_(self) -> str: + """RETURNS (str): ID of the original alias""" + return self.kb.vocab.strings[self.alias_hash] + + @property + def entity_freq(self) -> float: + return self.entity_freq + + @property + def entity_vector(self) -> Iterable[float]: + return self.entity_vector + + @property + def prior_prob(self) -> float: + """RETURNS (List[float]): Entity vector.""" + return self.prior_prob + + +def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: + """ + Return candidate entities for a given mention and fetching appropriate entries from the index. + kb (KnowledgeBase): Knowledge base to query. + mention (Span): Entity mention for which to identify candidates. + RETURNS (Iterable[Candidate]): Identified candidates. + """ + return kb.get_candidates(mention) + + +def get_candidates_all( + kb: KnowledgeBase, mentions: Generator[Iterable[Span], None, None] +) -> Iterator[Iterable[Iterable[Candidate]]]: + """ + Return candidate entities for the given mentions and fetching appropriate entries from the index. + kb (KnowledgeBase): Knowledge base to query. + mention (Generator[Iterable[Span]]): Entity mentions per document for which to identify candidates. + RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document. + """ + return kb.get_candidates_all(mentions) diff --git a/spacy/kb/candidate.pyx b/spacy/kb/candidate.pyx deleted file mode 100644 index 613b70483..000000000 --- a/spacy/kb/candidate.pyx +++ /dev/null @@ -1,76 +0,0 @@ -# cython: infer_types=True, profile=True - -from typing import Iterable, Generator, Iterator -from .kb cimport KnowledgeBase -from ..tokens import Span - -cdef class Candidate: - """A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved - to a specific `entity` from a Knowledge Base. This will be used as input for the entity linking - algorithm which will disambiguate the various candidates to the correct one. - Each candidate (alias, entity) pair is assigned a certain prior probability. - - DOCS: https://spacy.io/api/kb/#candidate-init - """ - - def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): - self.kb = kb - self.entity_hash = entity_hash - self.entity_freq = entity_freq - self.entity_vector = entity_vector - self.alias_hash = alias_hash - self.prior_prob = prior_prob - - @property - def entity(self) -> int: - """RETURNS (uint64): hash of the entity's KB ID/name""" - return self.entity_hash - - @property - def entity_(self) -> str: - """RETURNS (str): ID/name of this entity in the KB""" - return self.kb.vocab.strings[self.entity_hash] - - @property - def alias(self) -> int: - """RETURNS (uint64): hash of the alias""" - return self.alias_hash - - @property - def alias_(self) -> str: - """RETURNS (str): ID of the original alias""" - return self.kb.vocab.strings[self.alias_hash] - - @property - def entity_freq(self) -> float: - return self.entity_freq - - @property - def entity_vector(self) -> Iterable[float]: - return self.entity_vector - - @property - def prior_prob(self) -> float: - return self.prior_prob - - -def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: - """ - Return candidate entities for a given mention and fetching appropriate entries from the index. - kb (KnowledgeBase): Knowledge base to query. - mention (Span): Entity mention for which to identify candidates. - RETURNS (Iterable[Candidate]): Identified candidates. - """ - return kb.get_candidates(mention) - - -def get_candidates_all( - kb: KnowledgeBase, mentions: Generator[Iterable[Span], None, None] -) -> Iterator[Iterable[Iterable[Candidate]]]: - """ - Return candidate entities for the given mentions and fetching appropriate entries from the index. - kb (KnowledgeBase): Knowledge base to query. - mention (Generator[Iterable[Span]]): Entity mentions per document for which to identify candidates. - RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document. - """ - return kb.get_candidates_all(mentions) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 55a04e7ca..ef42aac9e 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -532,9 +532,12 @@ class EntityLinker(TrainablePipe): else: random.shuffle(candidates) # set all prior probabilities to 0 if incl_prior=False - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.incl_prior: - prior_probs = xp.asarray([0.0 for _ in candidates]) + prior_probs = xp.asarray( + [ + 0.0 if self.incl_prior else c.prior_prob + for c in candidates + ] + ) scores = prior_probs # add in similarity from the context if self.incl_context: