Simplify interface for int/str representations.

This commit is contained in:
Raphael Mitsch 2023-03-07 14:35:38 +01:00
parent 0c63940407
commit cea58ade89
4 changed files with 15 additions and 20 deletions

View File

@ -37,18 +37,13 @@ class Candidate(abc.ABC):
self._prior_prob = prior_prob self._prior_prob = prior_prob
@property @property
def entity_id(self) -> Union[str, int]: def entity_id(self) -> int:
"""RETURNS (Union[str, int]): Unique entity ID."""
return self._entity_id
@property
def entity_id_int(self) -> int:
"""RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID, """RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
otherwise the hash of the entity ID string).""" otherwise the hash of the entity ID string)."""
return self._entity_id_hash return self._entity_id_hash
@property @property
def entity_id_str(self) -> str: def entity_id_(self) -> str:
"""RETURNS (str): String representation of entity ID.""" """RETURNS (str): String representation of entity ID."""
return str(self._entity_id) return str(self._entity_id)
@ -111,6 +106,6 @@ class InMemoryCandidate(Candidate):
return self._entity_freq return self._entity_freq
@property @property
def entity_id_str(self) -> str: def entity_id_(self) -> str:
"""RETURNS (str): String representation of entity ID.""" """RETURNS (str): String representation of entity ID."""
return self._entity_id_str return self._entity_id_str

View File

@ -522,12 +522,12 @@ class EntityLinker(TrainablePipe):
) )
elif len(candidates) == 1 and self.threshold is None: elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate # shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_id_str) final_kb_ids.append(candidates[0].entity_id_)
self._add_activations( self._add_activations(
doc_scores=doc_scores, doc_scores=doc_scores,
doc_ents=doc_ents, doc_ents=doc_ents,
scores=[1.0], scores=[1.0],
ents=[candidates[0].entity_id_int], ents=[candidates[0].entity_id],
) )
else: else:
random.shuffle(candidates) random.shuffle(candidates)
@ -557,7 +557,7 @@ class EntityLinker(TrainablePipe):
raise ValueError(Errors.E161) raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs * sims) scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append( final_kb_ids.append(
candidates[scores.argmax().item()].entity_id_str candidates[scores.argmax().item()].entity_id_
if self.threshold is None if self.threshold is None
or scores.max() >= self.threshold or scores.max() >= self.threshold
else EntityLinker.NIL else EntityLinker.NIL
@ -566,7 +566,7 @@ class EntityLinker(TrainablePipe):
doc_scores=doc_scores, doc_scores=doc_scores,
doc_ents=doc_ents, doc_ents=doc_ents,
scores=scores, scores=scores,
ents=[c.entity_id_int for c in candidates], ents=[c.entity_id for c in candidates],
) )
self._add_doc_activations( self._add_doc_activations(
docs_scores=docs_scores, docs_scores=docs_scores,

View File

@ -471,7 +471,7 @@ def test_candidate_generation(nlp):
assert len(get_candidates(mykb, shrubbery_ent)) == 0 assert len(get_candidates(mykb, shrubbery_ent)) == 0
# test the content of the candidates # test the content of the candidates
assert get_candidates(mykb, adam_ent)[0].entity_id_str == "Q2" assert get_candidates(mykb, adam_ent)[0].entity_id_ == "Q2"
assert get_candidates(mykb, adam_ent)[0].mention == "adam" assert get_candidates(mykb, adam_ent)[0].mention == "adam"
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12) assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9) assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
@ -563,8 +563,8 @@ def test_vocab_serialization(nlp):
candidates = mykb._get_alias_candidates("adam") candidates = mykb._get_alias_candidates("adam")
assert len(candidates) == 1 assert len(candidates) == 1
assert candidates[0].entity_id_int == q2_hash assert candidates[0].entity_id == q2_hash
assert candidates[0].entity_id_str == "Q2" assert candidates[0].entity_id_ == "Q2"
assert candidates[0].mention == "adam" assert candidates[0].mention == "adam"
with make_tempdir() as d: with make_tempdir() as d:
@ -574,8 +574,8 @@ def test_vocab_serialization(nlp):
candidates = kb_new_vocab._get_alias_candidates("adam") candidates = kb_new_vocab._get_alias_candidates("adam")
assert len(candidates) == 1 assert len(candidates) == 1
assert candidates[0].entity_id_int == q2_hash assert candidates[0].entity_id == q2_hash
assert candidates[0].entity_id_str == "Q2" assert candidates[0].entity_id_ == "Q2"
assert candidates[0].mention == "adam" assert candidates[0].mention == "adam"
assert kb_new_vocab.get_vector("Q2") == [2] assert kb_new_vocab.get_vector("Q2") == [2]

View File

@ -67,17 +67,17 @@ def _check_kb(kb):
# check candidates & probabilities # check candidates & probabilities
candidates = sorted( candidates = sorted(
kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_str kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_
) )
assert len(candidates) == 2 assert len(candidates) == 2
assert candidates[0].entity_id_str == "Q007" assert candidates[0].entity_id_ == "Q007"
assert 6.999 < candidates[0].entity_freq < 7.01 assert 6.999 < candidates[0].entity_freq < 7.01
assert candidates[0].entity_vector == [0, 0, 7] assert candidates[0].entity_vector == [0, 0, 7]
assert candidates[0].mention == "double07" assert candidates[0].mention == "double07"
assert 0.899 < candidates[0].prior_prob < 0.901 assert 0.899 < candidates[0].prior_prob < 0.901
assert candidates[1].entity_id_str == "Q17" assert candidates[1].entity_id_ == "Q17"
assert 1.99 < candidates[1].entity_freq < 2.01 assert 1.99 < candidates[1].entity_freq < 2.01
assert candidates[1].entity_vector == [7, 1, 0] assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].mention == "double07" assert candidates[1].mention == "double07"