mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-09 23:53:10 +03:00
Refactor Candidate attributes and their usage.
This commit is contained in:
parent
94e57d0ed5
commit
38dce966e5
|
@ -14,41 +14,46 @@ class Candidate(abc.ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mention: str,
|
mention: str,
|
||||||
entity_id: int,
|
entity_id: Union[str, int],
|
||||||
entity_name: str,
|
|
||||||
entity_vector: List[float],
|
entity_vector: List[float],
|
||||||
prior_prob: float,
|
prior_prob: float,
|
||||||
):
|
):
|
||||||
"""Initializes properties of `Candidate` instance.
|
"""Initializes properties of `Candidate` instance.
|
||||||
mention (str): Mention text for this candidate.
|
mention (str): Mention text for this candidate.
|
||||||
entity_id (int): Unique entity ID.
|
entity_id (Union[str, int]): Unique entity ID.
|
||||||
entity_name (str): Entity name.
|
|
||||||
entity_vector (List[float]): Entity embedding.
|
entity_vector (List[float]): Entity embedding.
|
||||||
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
|
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
|
||||||
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
|
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
|
||||||
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
|
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
|
||||||
doesn't) it might be better to eschew this information and always supply the same value.
|
doesn't) it might be better to eschew this information and always supply the same value.
|
||||||
"""
|
"""
|
||||||
self._mention_ = mention
|
self._mention = mention
|
||||||
self._entity = entity_id
|
self._entity_id = entity_id
|
||||||
self._entity_ = entity_name
|
# Note that hashing an int value yields the same int value.
|
||||||
|
self._entity_id_hash = hash(entity_id)
|
||||||
self._entity_vector = entity_vector
|
self._entity_vector = entity_vector
|
||||||
self._prior_prob = prior_prob
|
self._prior_prob = prior_prob
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity(self) -> int:
|
def entity_id(self) -> Union[str, int]:
|
||||||
"""RETURNS (int): Unique entity ID."""
|
"""RETURNS (Union[str, int]): Unique entity ID."""
|
||||||
return self._entity
|
return self._entity_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity_(self) -> str:
|
def entity_id_int(self) -> int:
|
||||||
"""RETURNS (int): Entity name."""
|
"""RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
|
||||||
return self._entity_
|
otherwise the hash of the entity ID string)."""
|
||||||
|
return self._entity_id_hash
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mention_(self) -> str:
|
def entity_id_str(self) -> str:
|
||||||
|
"""RETURNS (str): String representation of entity ID."""
|
||||||
|
return str(self._entity_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mention(self) -> str:
|
||||||
"""RETURNS (str): Mention."""
|
"""RETURNS (str): Mention."""
|
||||||
return self._mention_
|
return self._mention
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity_vector(self) -> List[float]:
|
def entity_vector(self) -> List[float]:
|
||||||
|
@ -66,49 +71,40 @@ class InMemoryCandidate(Candidate):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
retrieve_string_from_hash: Callable[[int], str],
|
hash_to_str: Callable[[int], str],
|
||||||
entity_hash: int,
|
entity_id: int,
|
||||||
entity_freq: int,
|
mention: str,
|
||||||
entity_vector: List[float],
|
entity_vector: List[float],
|
||||||
mention_hash: int,
|
|
||||||
prior_prob: float,
|
prior_prob: float,
|
||||||
|
entity_freq: int
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
retrieve_string_from_hash (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab
|
hash_to_str (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab hash.
|
||||||
hash.
|
entity_id (str): Entity ID as hash that can be looked up with InMemoryKB.vocab.strings.__getitem__().
|
||||||
entity_hash (str): Hashed entity name /ID.
|
|
||||||
entity_freq (int): Entity frequency in KB corpus.
|
entity_freq (int): Entity frequency in KB corpus.
|
||||||
entity_vector (List[float]): Entity embedding.
|
entity_vector (List[float]): Entity embedding.
|
||||||
mention_hash (int): Hashed mention.
|
mention (str): Mention.
|
||||||
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
|
prior_prob (float): Prior probability of entity for this mention - i.e. the probability that, independent of
|
||||||
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
|
the context, this mention resolves to this entity_id in the corpus used to build the knowledge base. In
|
||||||
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
|
cases in which this isn't always possible (e.g.: the corpus to analyse contains mentions that the KB corpus
|
||||||
doesn't) it might be better to eschew this information and always supply the same value.
|
doesn't) it might be better to eschew this information and always supply the same value.
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
mention=retrieve_string_from_hash(mention_hash),
|
mention=mention,
|
||||||
entity_id=entity_hash,
|
entity_id=entity_id,
|
||||||
entity_name=retrieve_string_from_hash(entity_hash),
|
|
||||||
entity_vector=entity_vector,
|
entity_vector=entity_vector,
|
||||||
prior_prob=prior_prob,
|
prior_prob=prior_prob,
|
||||||
)
|
)
|
||||||
self._retrieve_string_from_hash = retrieve_string_from_hash
|
self._hash_to_str = hash_to_str
|
||||||
self._entity = entity_hash
|
|
||||||
self._entity_freq = entity_freq
|
self._entity_freq = entity_freq
|
||||||
self._mention = mention_hash
|
self._entity_id_str = self._hash_to_str(self._entity_id)
|
||||||
self._prior_prob = prior_prob
|
|
||||||
|
|
||||||
@property
|
|
||||||
def entity(self) -> int:
|
|
||||||
"""RETURNS (int): hash of the entity_id's KB ID/name"""
|
|
||||||
return self._entity
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mention(self) -> int:
|
|
||||||
"""RETURNS (int): Mention hash."""
|
|
||||||
return self._mention
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def entity_freq(self) -> float:
|
def entity_freq(self) -> float:
|
||||||
"""RETURNS (float): Relative entity frequency."""
|
"""RETURNS (float): Relative entity frequency."""
|
||||||
return self._entity_freq
|
return self._entity_freq
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entity_id_str(self) -> str:
|
||||||
|
"""RETURNS (str): String representation of entity ID."""
|
||||||
|
return self._entity_id_str
|
||||||
|
|
|
@ -240,12 +240,12 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
||||||
|
|
||||||
return [
|
return [
|
||||||
InMemoryCandidate(
|
InMemoryCandidate(
|
||||||
retrieve_string_from_hash=self.vocab.strings.__getitem__,
|
hash_to_str=self.vocab.strings.__getitem__,
|
||||||
entity_hash=self._entries[entry_index].entity_hash,
|
entity_id=self._entries[entry_index].entity_hash,
|
||||||
entity_freq=self._entries[entry_index].freq,
|
mention=alias,
|
||||||
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
||||||
mention_hash=alias_hash,
|
prior_prob=prior_prob,
|
||||||
prior_prob=prior_prob
|
entity_freq=self._entries[entry_index].freq
|
||||||
)
|
)
|
||||||
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||||
if entry_index != 0
|
if entry_index != 0
|
||||||
|
|
|
@ -535,12 +535,12 @@ class EntityLinker(TrainablePipe):
|
||||||
)
|
)
|
||||||
elif len(candidates) == 1 and self.threshold is None:
|
elif len(candidates) == 1 and self.threshold is None:
|
||||||
# shortcut for efficiency reasons: take the 1 candidate
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
final_kb_ids.append(candidates[0].entity_)
|
final_kb_ids.append(candidates[0].entity_id_str)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores=doc_scores,
|
doc_scores=doc_scores,
|
||||||
doc_ents=doc_ents,
|
doc_ents=doc_ents,
|
||||||
scores=[1.0],
|
scores=[1.0],
|
||||||
ents=[candidates[0].entity],
|
ents=[candidates[0].entity_id_int],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
@ -570,7 +570,7 @@ class EntityLinker(TrainablePipe):
|
||||||
raise ValueError(Errors.E161)
|
raise ValueError(Errors.E161)
|
||||||
scores = prior_probs + sims - (prior_probs * sims)
|
scores = prior_probs + sims - (prior_probs * sims)
|
||||||
final_kb_ids.append(
|
final_kb_ids.append(
|
||||||
candidates[scores.argmax().item()].entity_
|
candidates[scores.argmax().item()].entity_id_str
|
||||||
if self.threshold is None
|
if self.threshold is None
|
||||||
or scores.max() >= self.threshold
|
or scores.max() >= self.threshold
|
||||||
else EntityLinker.NIL
|
else EntityLinker.NIL
|
||||||
|
@ -579,7 +579,7 @@ class EntityLinker(TrainablePipe):
|
||||||
doc_scores=doc_scores,
|
doc_scores=doc_scores,
|
||||||
doc_ents=doc_ents,
|
doc_ents=doc_ents,
|
||||||
scores=scores,
|
scores=scores,
|
||||||
ents=[c.entity for c in candidates],
|
ents=[c.entity_id_int for c in candidates],
|
||||||
)
|
)
|
||||||
self._add_doc_activations(
|
self._add_doc_activations(
|
||||||
docs_scores=docs_scores,
|
docs_scores=docs_scores,
|
||||||
|
|
|
@ -468,8 +468,8 @@ def test_candidate_generation(nlp):
|
||||||
assert len(get_candidates(mykb, shrubbery_ent)) == 0
|
assert len(get_candidates(mykb, shrubbery_ent)) == 0
|
||||||
|
|
||||||
# test the content of the candidates
|
# test the content of the candidates
|
||||||
assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
|
assert get_candidates(mykb, adam_ent)[0].entity_id_str == "Q2"
|
||||||
assert get_candidates(mykb, adam_ent)[0].mention_ == "adam"
|
assert get_candidates(mykb, adam_ent)[0].mention == "adam"
|
||||||
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
|
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
|
||||||
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
|
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
|
||||||
|
|
||||||
|
@ -560,10 +560,9 @@ def test_vocab_serialization(nlp):
|
||||||
|
|
||||||
candidates = mykb._get_alias_candidates("adam")
|
candidates = mykb._get_alias_candidates("adam")
|
||||||
assert len(candidates) == 1
|
assert len(candidates) == 1
|
||||||
assert candidates[0].entity == q2_hash
|
assert candidates[0].entity_id_int == q2_hash
|
||||||
assert candidates[0].entity_ == "Q2"
|
assert candidates[0].entity_id_str == "Q2"
|
||||||
assert candidates[0].mention == adam_hash
|
assert candidates[0].mention == "adam"
|
||||||
assert candidates[0].mention_ == "adam"
|
|
||||||
|
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
mykb.to_disk(d / "kb")
|
mykb.to_disk(d / "kb")
|
||||||
|
@ -572,10 +571,9 @@ def test_vocab_serialization(nlp):
|
||||||
|
|
||||||
candidates = kb_new_vocab._get_alias_candidates("adam")
|
candidates = kb_new_vocab._get_alias_candidates("adam")
|
||||||
assert len(candidates) == 1
|
assert len(candidates) == 1
|
||||||
assert candidates[0].entity == q2_hash
|
assert candidates[0].entity_id_int == q2_hash
|
||||||
assert candidates[0].entity_ == "Q2"
|
assert candidates[0].entity_id_str == "Q2"
|
||||||
assert candidates[0].mention == adam_hash
|
assert candidates[0].mention == "adam"
|
||||||
assert candidates[0].mention_ == "adam"
|
|
||||||
|
|
||||||
assert kb_new_vocab.get_vector("Q2") == [2]
|
assert kb_new_vocab.get_vector("Q2") == [2]
|
||||||
assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)
|
assert_almost_equal(kb_new_vocab.get_prior_prob("Q2", "douglas"), 0.4)
|
||||||
|
|
|
@ -63,19 +63,19 @@ def _check_kb(kb):
|
||||||
assert alias_string not in kb.get_alias_strings()
|
assert alias_string not in kb.get_alias_strings()
|
||||||
|
|
||||||
# check candidates & probabilities
|
# check candidates & probabilities
|
||||||
candidates = sorted(kb._get_alias_candidates("double07"), key=lambda x: x.entity_)
|
candidates = sorted(kb._get_alias_candidates("double07"), key=lambda x: x.entity_id_str)
|
||||||
assert len(candidates) == 2
|
assert len(candidates) == 2
|
||||||
|
|
||||||
assert candidates[0].entity_ == "Q007"
|
assert candidates[0].entity_id_str == "Q007"
|
||||||
assert 6.999 < candidates[0].entity_freq < 7.01
|
assert 6.999 < candidates[0].entity_freq < 7.01
|
||||||
assert candidates[0].entity_vector == [0, 0, 7]
|
assert candidates[0].entity_vector == [0, 0, 7]
|
||||||
assert candidates[0].mention_ == "double07"
|
assert candidates[0].mention == "double07"
|
||||||
assert 0.899 < candidates[0].prior_prob < 0.901
|
assert 0.899 < candidates[0].prior_prob < 0.901
|
||||||
|
|
||||||
assert candidates[1].entity_ == "Q17"
|
assert candidates[1].entity_id_str == "Q17"
|
||||||
assert 1.99 < candidates[1].entity_freq < 2.01
|
assert 1.99 < candidates[1].entity_freq < 2.01
|
||||||
assert candidates[1].entity_vector == [7, 1, 0]
|
assert candidates[1].entity_vector == [7, 1, 0]
|
||||||
assert candidates[1].mention_ == "double07"
|
assert candidates[1].mention == "double07"
|
||||||
assert 0.099 < candidates[1].prior_prob < 0.101
|
assert 0.099 < candidates[1].prior_prob < 0.101
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user