hash the entity name

This commit is contained in:
svlandeg 2019-03-15 15:00:53 +01:00
parent 839dafa104
commit feb71e15fd
2 changed files with 11 additions and 9 deletions

View File

@ -2,8 +2,9 @@
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap from preshed.maps cimport PreshMap
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.stdint cimport int32_t from libc.stdint cimport int32_t, int64_t
from spacy.typedefs cimport attr_t from .typedefs cimport attr_t, hash_t
from .strings cimport hash_string
# Internal struct, for storage and disambiguation. This isn't what we return # Internal struct, for storage and disambiguation. This isn't what we return
@ -70,21 +71,20 @@ cdef class KnowledgeBase:
def __len__(self): def __len__(self):
return self._entries.size() return self._entries.size()
def add(self, name, float prob, vectors=None, features=None, aliases=None): def add_entity(self, name, float prob, vectors=None, features=None, aliases=None):
# TODO: more friendly check for non-unique name # TODO: more friendly check for non-unique name
if name in self: if name in self:
return return
# TODO: convert name to hash cdef hash_t key = hash_string(name)
cdef attr_t orth = get_string_name(name) self.c_add_entity(key, prob, self._vectors_table.get_pointer(vectors),
self.c_add(orth, prob, self._vectors_table.get_pointer(vectors),
self._features_table.get(features)) self._features_table.get(features))
# TODO: hash the aliases? # TODO: hash the aliases?
for alias, prob_alias in aliases: for alias, prob_alias in aliases:
self._aliases_table.add(alias, orth, prob_alias) self._aliases_table.add(alias, key, prob_alias)
cdef void c_add(self, attr_t orth, float prob, const int32_t* vector_rows, cdef void c_add_entity(self, hash_t key, float prob, const int32_t* vector_rows,
int feats_row) nogil: int feats_row) nogil:
"""Add an entry to the knowledge base.""" """Add an entry to the knowledge base."""
# This is what we'll map the orth to. It's where the entry will sit # This is what we'll map the orth to. It's where the entry will sit
@ -96,5 +96,5 @@ cdef class KnowledgeBase:
feats_row=feats_row, feats_row=feats_row,
prob=prob prob=prob
)) ))
self._index[orth] = index self._index[key] = index
return index return index

View File

@ -661,6 +661,8 @@ cdef class Span:
"""RETURNS (unicode): The named entity's KB ID.""" """RETURNS (unicode): The named entity's KB ID."""
def __get__(self): def __get__(self):
return self.doc.vocab.strings[self.kb_id] return self.doc.vocab.strings[self.kb_id]
# TODO: custom error msg like for label_
def __set__(self, unicode kb_id_): def __set__(self, unicode kb_id_):
raise NotImplementedError(TempErrors.T007.format(attr='kb_id_')) raise NotImplementedError(TempErrors.T007.format(attr='kb_id_'))