Refacor Candidate attributes and their usage.

This commit is contained in:
Raphael Mitsch 2023-03-05 13:49:13 +01:00
parent 94e57d0ed5
commit 38dce966e5
5 changed files with 58 additions and 64 deletions

View File

@ -14,41 +14,46 @@ class Candidate(abc.ABC):
def __init__( def __init__(
self, self,
mention: str, mention: str,
entity_id: int, entity_id: Union[str, int],
entity_name: str,
entity_vector: List[float], entity_vector: List[float],
prior_prob: float, prior_prob: float,
): ):
"""Initializes properties of `Candidate` instance. """Initializes properties of `Candidate` instance.
mention (str): Mention text for this candidate. mention (str): Mention text for this candidate.
entity_id (int): Unique entity ID. entity_id (Union[str, int]): Unique entity ID.
entity_name (str): Entity name.
entity_vector (List[float]): Entity embedding. entity_vector (List[float]): Entity embedding.
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
doesn't) it might be better to eschew this information and always supply the same value. doesn't) it might be better to eschew this information and always supply the same value.
""" """
self._mention_ = mention self._mention = mention
self._entity = entity_id self._entity_id = entity_id
self._entity_ = entity_name # Note that hashing an int value yields the same int value.
self._entity_id_hash = hash(entity_id)
self._entity_vector = entity_vector self._entity_vector = entity_vector
self._prior_prob = prior_prob self._prior_prob = prior_prob
@property @property
def entity(self) -> int: def entity_id(self) -> Union[str, int]:
"""RETURNS (int): Unique entity ID.""" """RETURNS (Union[str, int]): Unique entity ID."""
return self._entity return self._entity_id
@property @property
def entity_(self) -> str: def entity_id_int(self) -> int:
"""RETURNS (int): Entity name.""" """RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
return self._entity_ otherwise the hash of the entity ID string)."""
return self._entity_id_hash
@property @property
def mention_(self) -> str: def entity_id_str(self) -> str:
"""RETURNS (str): String representation of entity ID."""
return str(self._entity_id)
@property
def mention(self) -> str:
"""RETURNS (str): Mention.""" """RETURNS (str): Mention."""
return self._mention_ return self._mention
@property @property
def entity_vector(self) -> List[float]: def entity_vector(self) -> List[float]:
@ -66,49 +71,40 @@ class InMemoryCandidate(Candidate):
def __init__( def __init__(
self, self,
retrieve_string_from_hash: Callable[[int], str], hash_to_str: Callable[[int], str],
entity_hash: int, entity_id: int,
entity_freq: int, mention: str,
entity_vector: List[float], entity_vector: List[float],
mention_hash: int,
prior_prob: float, prior_prob: float,
entity_freq: int
): ):
""" """
retrieve_string_from_hash (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab hash_to_str (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab hash.
hash. entity_id (str): Entity ID as hash that can be looked up with InMemoryKB.vocab.strings.__getitem__().
entity_hash (str): Hashed entity name /ID.
entity_freq (int): Entity frequency in KB corpus. entity_freq (int): Entity frequency in KB corpus.
entity_vector (List[float]): Entity embedding. entity_vector (List[float]): Entity embedding.
mention_hash (int): Hashed mention. mention (str): Mention.
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
doesn't) it might be better to eschew this information and always supply the same value. doesn't) it might be better to eschew this information and always supply the same value.
""" """
super().__init__( super().__init__(
mention=retrieve_string_from_hash(mention_hash), mention=mention,
entity_id=entity_hash, entity_id=entity_id,
entity_name=retrieve_string_from_hash(entity_hash),
entity_vector=entity_vector, entity_vector=entity_vector,
prior_prob=prior_prob, prior_prob=prior_prob,
) )
self._retrieve_string_from_hash = retrieve_string_from_hash self._hash_to_str = hash_to_str
self._entity = entity_hash
self._entity_freq = entity_freq self._entity_freq = entity_freq
self._mention = mention_hash self._entity_id_str = self._hash_to_str(self._entity_id)
self._prior_prob = prior_prob
@property
def entity(self) -> int:
"""RETURNS (int): hash of the entity_id's KB ID/name"""
return self._entity
@property
def mention(self) -> int:
"""RETURNS (int): Mention hash."""
return self._mention
@property @property
def entity_freq(self) -> float: def entity_freq(self) -> float:
"""RETURNS (float): Relative entity frequency.""" """RETURNS (float): Relative entity frequency."""
return self._entity_freq return self._entity_freq
@property
def entity_id_str(self) -> str:
"""RETURNS (str): String representation of entity ID."""
return self._entity_id_str

View File

@ -240,12 +240,12 @@ cdef class InMemoryLookupKB(KnowledgeBase):
return [ return [
InMemoryCandidate( InMemoryCandidate(
retrieve_string_from_hash=self.vocab.strings.__getitem__, hash_to_str=self.vocab.strings.__getitem__,
entity_hash=self._entries[entry_index].entity_hash, entity_id=self._entries[entry_index].entity_hash,
entity_freq=self._entries[entry_index].freq, mention=alias,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index], entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
mention_hash=alias_hash, prior_prob=prior_prob,
prior_prob=prior_prob entity_freq=self._entries[entry_index].freq
) )
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0 if entry_index != 0

View File

@ -535,12 +535,12 @@ class EntityLinker(TrainablePipe):
) )
elif len(candidates) == 1 and self.threshold is None: elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate # shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_) final_kb_ids.append(candidates[0].entity_id_str)
self._add_activations( self._add_activations(
doc_scores=doc_scores, doc_scores=doc_scores,
doc_ents=doc_ents, doc_ents=doc_ents,
scores=[1.0], scores=[1.0],
ents=[candidates[0].entity], ents=[candidates[0].entity_id_int],
) )
else: else:
random.shuffle(candidates) random.shuffle(candidates)
@ -570,7 +570,7 @@ class EntityLinker(TrainablePipe):
raise ValueError(Errors.E161) raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs * sims) scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append( final_kb_ids.append(
candidates[scores.argmax().item()].entity_ candidates[scores.argmax().item()].entity_id_str
if self.threshold is None if self.threshold is None
or scores.max() >= self.threshold or scores.max() >= self.threshold
else EntityLinker.NIL else EntityLinker.NIL
@ -579,7 +579,7 @@ class EntityLinker(TrainablePipe):
doc_scores=doc_scores, doc_scores=doc_scores,
doc_ents=doc_ents, doc_ents=doc_ents,
scores=scores, scores=scores,
ents=[c.entity for c in candidates], ents=[c.entity_id_int for c in candidates],
) )
self._add_doc_activations( self._add_doc_activations(
docs_scores=docs_scores, docs_scores=docs_scores,

View File

@ -468,8 +468,8 @@ def test_candidate_generation(nlp):
assert len(get_candidates(mykb, shrubbery_ent)) == 0 assert len(get_candidates(mykb, shrubbery_ent)) == 0
# test the content of the candidates # test the content of the candidates
assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2" assert get_candidates(mykb, adam_ent)[0].entity_id_str == "Q2"
assert get_candidates(mykb, adam_ent)[0].mention_ == "adam" assert get_candidates(mykb, adam_ent)[0].mention == "adam"
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12) assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9) assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
@ -560,10 +560,9 @@ def test_vocab_serialization(nlp):
candidates = mykb._get_alias_candidates("adam") candidates = mykb._get_alias_candidates("adam")
assert len(candidates) == 1 assert len(candidates) == 1
assert candidates[0].entity == q2_hash assert candidates[0].entity_id_int == q2_hash
assert candidates[0].entity_ == "Q2" assert candidates[0].entity_id_str == "Q2"
assert candidates[0].mention == adam_hash assert candidates[0].mention == "adam"
assert candidates[0].mention_ == "adam"
with make_tempdir() as d: with make_tempdir() as d:
mykb.to_disk(d / "kb") mykb.to_disk(d / "kb")
@ -572,10 +571,9 @@ def test_vocab_serialization(nlp):
candidates = kb_new_vocab._get_alias_candidates("adam") candidates = kb_new_vocab._get_alias_candidates("adam")
assert len(candidates) == 1 assert len(candidates) == 1
assert candidates[0].entity == q2_hash assert candidates[0].entity_id_int == q2_hash
assert candidates[0].entity_ == "Q2" assert candidates[0].entity_id_str == "Q2"
assert candidates[0].mention == adam_hash assert candidates[0].mention == "adam"
assert candidates[0].mention_ == "adam"
assert kb_new_vocab.get_vector("Q2") == [2] assert kb_new_vocab.get_vector("Q2") == [2]
assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4) assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)

View File

@ -63,19 +63,19 @@ def _check_kb(kb):
assert alias_string not in kb.get_alias_strings() assert alias_string not in kb.get_alias_strings()
# check candidates & probabilities # check candidates & probabilities
candidates = sorted(kb._get_alias_candidates("double07"), key=lambda x: x.entity_) candidates = sorted(kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_str)
assert len(candidates) == 2 assert len(candidates) == 2
assert candidates[0].entity_ == "Q007" assert candidates[0].entity_id_str == "Q007"
assert 6.999 < candidates[0].entity_freq < 7.01 assert 6.999 < candidates[0].entity_freq < 7.01
assert candidates[0].entity_vector == [0, 0, 7] assert candidates[0].entity_vector == [0, 0, 7]
assert candidates[0].mention_ == "double07" assert candidates[0].mention == "double07"
assert 0.899 < candidates[0].prior_prob < 0.901 assert 0.899 < candidates[0].prior_prob < 0.901
assert candidates[1].entity_ == "Q17" assert candidates[1].entity_id_str == "Q17"
assert 1.99 < candidates[1].entity_freq < 2.01 assert 1.99 < candidates[1].entity_freq < 2.01
assert candidates[1].entity_vector == [7, 1, 0] assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].mention_ == "double07" assert candidates[1].mention == "double07"
assert 0.099 < candidates[1].prior_prob < 0.101 assert 0.099 < candidates[1].prior_prob < 0.101