diff --git a/spacy/kb.pxd b/spacy/kb.pxd index e715cad88..9d9a21a8c 100644 --- a/spacy/kb.pxd +++ b/spacy/kb.pxd @@ -3,8 +3,7 @@ from cymem.cymem cimport Pool from preshed.maps cimport PreshMap from libcpp.vector cimport vector from libc.stdint cimport int32_t, int64_t -from .typedefs cimport attr_t, hash_t -from .strings cimport hash_string +from .typedefs cimport hash_t # Internal struct, for storage and disambiguation. This isn't what we return @@ -68,26 +67,10 @@ cdef class KnowledgeBase: # efficient. cdef object _aliases_table - def __len__(self): - return self._entries.size() - - def add_entity(self, name, float prob, vectors=None, features=None, aliases=None): - # TODO: more friendly check for non-unique name - if name in self: - return - - cdef hash_t key = hash_string(name) - self.c_add_entity(key, prob, self._vectors_table.get_pointer(vectors), - self._features_table.get(features)) - - # TODO: hash the aliases? - for alias, prob_alias in aliases: - self._aliases_table.add(alias, key, prob_alias) - cdef void c_add_entity(self, hash_t key, float prob, const int32_t* vector_rows, int feats_row) nogil: """Add an entry to the knowledge base.""" - # This is what we'll map the orth to. It's where the entry will sit + # This is what we'll map the hash key to. It's where the entry will sit # in the vector of entries, so we can get it later. 
from .strings cimport hash_string


cdef class KnowledgeBase:
    """Python-facing methods of the knowledge base.

    The C-level attributes and `c_add_entity` used here are declared in the
    accompanying `kb.pxd`.
    """

    def __len__(self):
        """Return the number of entries stored, i.e. `self._entries.size()`."""
        return self._entries.size()

    def add_entity(self, name, float prob, vectors=None, features=None, aliases=None):
        """Add an entity to the KB, unless `name` is already contained in it.

        name: The entity name; it is hashed with `hash_string` to obtain the key.
        prob: The probability passed on to `c_add_entity` for this entity.
        vectors: Passed to `self._vectors_table.get_pointer`.
        features: Passed to `self._features_table.get`.
        aliases: Accepted for backwards compatibility but not used here;
            aliases are registered through `add_alias`.

        Returns None. If `name in self` is true, the call is a silent no-op.
        """
        # TODO: more friendly check for non-unique name
        if name in self:
            return

        cdef hash_t name_hash = hash_string(name)
        self.c_add_entity(name_hash, prob, self._vectors_table.get_pointer(vectors),
                          self._features_table.get(features))

    def add_alias(self, alias, entities, probabilities):
        """For a given alias, add its potential entities and prior probabilities to the KB.

        alias: The alias string; it is hashed with `hash_string`.
        entities: Iterable of entity names.
        probabilities: Iterable of prior probabilities, paired with `entities`.

        `entities` and `probabilities` are walked in lockstep with `zip`, so
        surplus items in the longer of the two are ignored.
        """
        # C-typed locals have to be declared at function scope: Cython does
        # not allow `cdef` statements inside a loop body.
        cdef hash_t alias_hash = hash_string(alias)
        cdef hash_t entity_hash
        cdef int64_t entity_index

        # TODO: check len(entities) == len(probabilities)
        for entity, prob in zip(entities, probabilities):
            entity_hash = hash_string(entity)
            # Look up the entity's stored index under its hash key.
            entity_index = self._index[entity_hash]
            # TODO: check that entity is already in this KB (entity_index is OK)
            self._aliases_table.add(alias_hash, entity_index, prob)