property getters and keep track of KB internally

This commit is contained in:
svlandeg 2019-03-21 13:26:12 +01:00
parent 98ae77a682
commit 1289cd6e8f
3 changed files with 46 additions and 30 deletions

View File

@ -46,6 +46,7 @@ cdef struct _AliasC:
# TODO: document # TODO: document
cdef class Candidate: cdef class Candidate:
cdef readonly KnowledgeBase kb
cdef hash_t entity_hash cdef hash_t entity_hash
cdef hash_t alias_hash cdef hash_t alias_hash
cdef float prior_prob cdef float prior_prob

View File

@ -5,16 +5,31 @@ from spacy.errors import user_warning
cdef class Candidate: cdef class Candidate:
def __init__(self, entity_hash, alias_hash, prior_prob): def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob):
self.kb = kb
self.entity_hash = entity_hash self.entity_hash = entity_hash
self.alias_hash = alias_hash self.alias_hash = alias_hash
self.prior_prob = prior_prob self.prior_prob = prior_prob
def get_entity_name(self, KnowledgeBase kb): property kb_id_:
return kb.strings[self.entity_hash] """RETURNS (unicode): ID of this entity in the KB"""
def __get__(self):
return self.kb.strings[self.entity_hash]
def get_alias_name(self, KnowledgeBase kb): property kb_id:
return kb.strings[self.alias_hash] """RETURNS (uint64): hash of the entity's KB ID"""
def __get__(self):
return self.entity_hash
property alias_:
"""RETURNS (unicode): ID of the original alias"""
def __get__(self):
return self.kb.strings[self.alias_hash]
property alias:
"""RETURNS (uint64): hash of the alias"""
def __get__(self):
return self.alias_hash
property prior_prob: property prior_prob:
def __get__(self): def __get__(self):
@ -40,6 +55,10 @@ cdef class KnowledgeBase:
return self._aliases_table.size() - 1 # not counting dummy element on index 0 return self._aliases_table.size() - 1 # not counting dummy element on index 0
def add_entity(self, unicode entity_id, float prob, vectors=None, features=None): def add_entity(self, unicode entity_id, float prob, vectors=None, features=None):
"""
Add an entity to the KB.
Return the hash of the entity ID at the end
"""
cdef hash_t id_hash = self.strings.add(entity_id) cdef hash_t id_hash = self.strings.add(entity_id)
# Return if this entity was added before # Return if this entity was added before
@ -52,8 +71,13 @@ cdef class KnowledgeBase:
# TODO self._vectors_table.get_pointer(vectors), # TODO self._vectors_table.get_pointer(vectors),
# self._features_table.get(features)) # self._features_table.get(features))
return id_hash
def add_alias(self, unicode alias, entities, probabilities): def add_alias(self, unicode alias, entities, probabilities):
"""For a given alias, add its potential entities and prior probabilies to the KB.""" """
For a given alias, add its potential entities and prior probabilies to the KB.
Return the alias_hash at the end
"""
# Throw an error if the length of entities and probabilities are not the same # Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities): if not len(entities) == len(probabilities):
@ -91,13 +115,16 @@ cdef class KnowledgeBase:
self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs) self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs)
return alias_hash
def get_candidates(self, unicode alias): def get_candidates(self, unicode alias):
cdef hash_t alias_hash = self.strings[alias] cdef hash_t alias_hash = self.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
return [Candidate(entity_hash=self._entries[entry_index].entity_hash, return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash,
alias_hash=alias_hash, alias_hash=alias_hash,
prior_prob=prob) prior_prob=prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)] for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)]

View File

@ -38,28 +38,16 @@ def create_kb():
print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases()) print("kb size", len(mykb), mykb.get_size_entities(), mykb.get_size_aliases())
print("candidates for", alias1) for alias in [alias1, alias2, alias3]:
candidates1 = mykb.get_candidates(alias1) print()
for candidate in candidates1: print("candidates for", alias)
candidates = mykb.get_candidates(alias)
for candidate in candidates:
print(" candidate") print(" candidate")
print(" name", candidate.get_entity_name(mykb)) print(" kb_id", candidate.kb_id)
print(" alias", candidate.get_alias_name(mykb)) print(" kb_id_", candidate.kb_id_)
print(" prior_prob", candidate.prior_prob) print(" alias", candidate.alias)
print(" alias_", candidate.alias_)
print("candidates for", alias2)
candidates2 = mykb.get_candidates(alias2)
for candidate in candidates2:
print(" candidate")
print(" name", candidate.get_entity_name(mykb))
print(" alias", candidate.get_alias_name(mykb))
print(" prior_prob", candidate.prior_prob)
print("candidates for", alias3)
candidates3 = mykb.get_candidates(alias3)
for candidate in candidates3:
print(" candidate")
print(" name", candidate.get_entity_name(mykb))
print(" alias", candidate.get_alias_name(mykb))
print(" prior_prob", candidate.prior_prob) print(" prior_prob", candidate.prior_prob)