Refacor Candidate attributes and their usage.

This commit is contained in:
Raphael Mitsch 2023-03-05 13:49:13 +01:00
parent 94e57d0ed5
commit 38dce966e5
5 changed files with 58 additions and 64 deletions

View File

@ -14,41 +14,46 @@ class Candidate(abc.ABC):
def __init__(
self,
mention: str,
entity_id: int,
entity_name: str,
entity_id: Union[str, int],
entity_vector: List[float],
prior_prob: float,
):
"""Initializes properties of `Candidate` instance.
mention (str): Mention text for this candidate.
entity_id (int): Unique entity ID.
entity_name (str): Entity name.
entity_id (Union[str, int]): Unique entity ID.
entity_vector (List[float]): Entity embedding.
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
doesn't) it might be better to eschew this information and always supply the same value.
"""
self._mention_ = mention
self._entity = entity_id
self._entity_ = entity_name
self._mention = mention
self._entity_id = entity_id
# Note that hashing an int value yields the same int value.
self._entity_id_hash = hash(entity_id)
self._entity_vector = entity_vector
self._prior_prob = prior_prob
@property
def entity(self) -> int:
"""RETURNS (int): Unique entity ID."""
return self._entity
def entity_id(self) -> Union[str, int]:
"""RETURNS (Union[str, int]): Unique entity ID."""
return self._entity_id
@property
def entity_(self) -> str:
"""RETURNS (int): Entity name."""
return self._entity_
def entity_id_int(self) -> int:
"""RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
otherwise the hash of the entity ID string)."""
return self._entity_id_hash
@property
def mention_(self) -> str:
def entity_id_str(self) -> str:
"""RETURNS (str): String representation of entity ID."""
return str(self._entity_id)
@property
def mention(self) -> str:
"""RETURNS (str): Mention."""
return self._mention_
return self._mention
@property
def entity_vector(self) -> List[float]:
@ -66,49 +71,40 @@ class InMemoryCandidate(Candidate):
def __init__(
self,
retrieve_string_from_hash: Callable[[int], str],
entity_hash: int,
entity_freq: int,
hash_to_str: Callable[[int], str],
entity_id: int,
mention: str,
entity_vector: List[float],
mention_hash: int,
prior_prob: float,
entity_freq: int
):
"""
retrieve_string_from_hash (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab
hash.
entity_hash (str): Hashed entity name /ID.
hash_to_str (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab hash.
entity_id (str): Entity ID as hash that can be looked up with InMemoryKB.vocab.strings.__getitem__().
entity_freq (int): Entity frequency in KB corpus.
entity_vector (List[float]): Entity embedding.
mention_hash (int): Hashed mention.
mention (str): Mention.
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
doesn't) it might be better to eschew this information and always supply the same value.
"""
super().__init__(
mention=retrieve_string_from_hash(mention_hash),
entity_id=entity_hash,
entity_name=retrieve_string_from_hash(entity_hash),
mention=mention,
entity_id=entity_id,
entity_vector=entity_vector,
prior_prob=prior_prob,
)
self._retrieve_string_from_hash = retrieve_string_from_hash
self._entity = entity_hash
self._hash_to_str = hash_to_str
self._entity_freq = entity_freq
self._mention = mention_hash
self._prior_prob = prior_prob
@property
def entity(self) -> int:
"""RETURNS (int): hash of the entity_id's KB ID/name"""
return self._entity
@property
def mention(self) -> int:
"""RETURNS (int): Mention hash."""
return self._mention
self._entity_id_str = self._hash_to_str(self._entity_id)
@property
def entity_freq(self) -> float:
"""RETURNS (float): Relative entity frequency."""
return self._entity_freq
@property
def entity_id_str(self) -> str:
"""RETURNS (str): String representation of entity ID."""
return self._entity_id_str

View File

@ -240,12 +240,12 @@ cdef class InMemoryLookupKB(KnowledgeBase):
return [
InMemoryCandidate(
retrieve_string_from_hash=self.vocab.strings.__getitem__,
entity_hash=self._entries[entry_index].entity_hash,
entity_freq=self._entries[entry_index].freq,
hash_to_str=self.vocab.strings.__getitem__,
entity_id=self._entries[entry_index].entity_hash,
mention=alias,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
mention_hash=alias_hash,
prior_prob=prior_prob
prior_prob=prior_prob,
entity_freq=self._entries[entry_index].freq
)
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0

View File

@ -535,12 +535,12 @@ class EntityLinker(TrainablePipe):
)
elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
final_kb_ids.append(candidates[0].entity_id_str)
self._add_activations(
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[1.0],
ents=[candidates[0].entity],
ents=[candidates[0].entity_id_int],
)
else:
random.shuffle(candidates)
@ -570,7 +570,7 @@ class EntityLinker(TrainablePipe):
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append(
candidates[scores.argmax().item()].entity_
candidates[scores.argmax().item()].entity_id_str
if self.threshold is None
or scores.max() >= self.threshold
else EntityLinker.NIL
@ -579,7 +579,7 @@ class EntityLinker(TrainablePipe):
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=scores,
ents=[c.entity for c in candidates],
ents=[c.entity_id_int for c in candidates],
)
self._add_doc_activations(
docs_scores=docs_scores,

View File

@ -468,8 +468,8 @@ def test_candidate_generation(nlp):
assert len(get_candidates(mykb, shrubbery_ent)) == 0
# test the content of the candidates
assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
assert get_candidates(mykb, adam_ent)[0].mention_ == "adam"
assert get_candidates(mykb, adam_ent)[0].entity_id_str == "Q2"
assert get_candidates(mykb, adam_ent)[0].mention == "adam"
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
@ -560,10 +560,9 @@ def test_vocab_serialization(nlp):
candidates = mykb._get_alias_candidates("adam")
assert len(candidates) == 1
assert candidates[0].entity == q2_hash
assert candidates[0].entity_ == "Q2"
assert candidates[0].mention == adam_hash
assert candidates[0].mention_ == "adam"
assert candidates[0].entity_id_int == q2_hash
assert candidates[0].entity_id_str == "Q2"
assert candidates[0].mention == "adam"
with make_tempdir() as d:
mykb.to_disk(d / "kb")
@ -572,10 +571,9 @@ def test_vocab_serialization(nlp):
candidates = kb_new_vocab._get_alias_candidates("adam")
assert len(candidates) == 1
assert candidates[0].entity == q2_hash
assert candidates[0].entity_ == "Q2"
assert candidates[0].mention == adam_hash
assert candidates[0].mention_ == "adam"
assert candidates[0].entity_id_int == q2_hash
assert candidates[0].entity_id_str == "Q2"
assert candidates[0].mention == "adam"
assert kb_new_vocab.get_vector("Q2") == [2]
assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)

View File

@ -63,19 +63,19 @@ def _check_kb(kb):
assert alias_string not in kb.get_alias_strings()
# check candidates & probabilities
candidates = sorted(kb._get_alias_candidates("double07"), key=lambda x: x.entity_)
candidates = sorted(kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_str)
assert len(candidates) == 2
assert candidates[0].entity_ == "Q007"
assert candidates[0].entity_id_str == "Q007"
assert 6.999 < candidates[0].entity_freq < 7.01
assert candidates[0].entity_vector == [0, 0, 7]
assert candidates[0].mention_ == "double07"
assert candidates[0].mention == "double07"
assert 0.899 < candidates[0].prior_prob < 0.901
assert candidates[1].entity_ == "Q17"
assert candidates[1].entity_id_str == "Q17"
assert 1.99 < candidates[1].entity_freq < 2.01
assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].mention_ == "double07"
assert candidates[1].mention == "double07"
assert 0.099 < candidates[1].prior_prob < 0.101