Simplify interface for int/str representations.

This commit is contained in:
Raphael Mitsch 2023-03-07 14:35:38 +01:00
parent 0c63940407
commit cea58ade89
4 changed files with 15 additions and 20 deletions

View File

@ -37,18 +37,13 @@ class Candidate(abc.ABC):
self._prior_prob = prior_prob
@property
def entity_id(self) -> Union[str, int]:
"""RETURNS (Union[str, int]): Unique entity ID."""
return self._entity_id
@property
def entity_id_int(self) -> int:
def entity_id(self) -> int:
"""RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
otherwise the hash of the entity ID string)."""
return self._entity_id_hash
@property
def entity_id_str(self) -> str:
def entity_id_(self) -> str:
"""RETURNS (str): String representation of entity ID."""
return str(self._entity_id)
@ -111,6 +106,6 @@ class InMemoryCandidate(Candidate):
return self._entity_freq
@property
def entity_id_str(self) -> str:
def entity_id_(self) -> str:
"""RETURNS (str): String representation of entity ID."""
return self._entity_id_str

View File

@ -522,12 +522,12 @@ class EntityLinker(TrainablePipe):
)
elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_id_str)
final_kb_ids.append(candidates[0].entity_id_)
self._add_activations(
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[1.0],
ents=[candidates[0].entity_id_int],
ents=[candidates[0].entity_id],
)
else:
random.shuffle(candidates)
@ -557,7 +557,7 @@ class EntityLinker(TrainablePipe):
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append(
candidates[scores.argmax().item()].entity_id_str
candidates[scores.argmax().item()].entity_id_
if self.threshold is None
or scores.max() >= self.threshold
else EntityLinker.NIL
@ -566,7 +566,7 @@ class EntityLinker(TrainablePipe):
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=scores,
ents=[c.entity_id_int for c in candidates],
ents=[c.entity_id for c in candidates],
)
self._add_doc_activations(
docs_scores=docs_scores,

View File

@ -471,7 +471,7 @@ def test_candidate_generation(nlp):
assert len(get_candidates(mykb, shrubbery_ent)) == 0
# test the content of the candidates
assert get_candidates(mykb, adam_ent)[0].entity_id_str == "Q2"
assert get_candidates(mykb, adam_ent)[0].entity_id_ == "Q2"
assert get_candidates(mykb, adam_ent)[0].mention == "adam"
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
@ -563,8 +563,8 @@ def test_vocab_serialization(nlp):
candidates = mykb._get_alias_candidates("adam")
assert len(candidates) == 1
assert candidates[0].entity_id_int == q2_hash
assert candidates[0].entity_id_str == "Q2"
assert candidates[0].entity_id == q2_hash
assert candidates[0].entity_id_ == "Q2"
assert candidates[0].mention == "adam"
with make_tempdir() as d:
@ -574,8 +574,8 @@ def test_vocab_serialization(nlp):
candidates = kb_new_vocab._get_alias_candidates("adam")
assert len(candidates) == 1
assert candidates[0].entity_id_int == q2_hash
assert candidates[0].entity_id_str == "Q2"
assert candidates[0].entity_id == q2_hash
assert candidates[0].entity_id_ == "Q2"
assert candidates[0].mention == "adam"
assert kb_new_vocab.get_vector("Q2") == [2]

View File

@ -67,17 +67,17 @@ def _check_kb(kb):
# check candidates & probabilities
candidates = sorted(
kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_str
kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_
)
assert len(candidates) == 2
assert candidates[0].entity_id_str == "Q007"
assert candidates[0].entity_id_ == "Q007"
assert 6.999 < candidates[0].entity_freq < 7.01
assert candidates[0].entity_vector == [0, 0, 7]
assert candidates[0].mention == "double07"
assert 0.899 < candidates[0].prior_prob < 0.901
assert candidates[1].entity_id_str == "Q17"
assert candidates[1].entity_id_ == "Q17"
assert 1.99 < candidates[1].entity_freq < 2.01
assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].mention == "double07"