Fix kb/kb_in_memory.pyx.

This commit is contained in:
Raphael Mitsch 2023-07-03 10:02:40 +02:00
parent 6c483971b7
commit a0bf50661b

View File

@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from typing import Any, Callable, Dict, Iterable, Union from typing import Any, Callable, Dict, Iterable
import srsly import srsly
@ -27,8 +27,9 @@ from .candidate import Candidate as Candidate
cdef class InMemoryLookupKB(KnowledgeBase): cdef class InMemoryLookupKB(KnowledgeBase):
"""An `InMemoryLookupKB` instance stores unique identifiers for entities and their textual aliases, """An `InMemoryLookupKB` instance stores unique identifiers for entities
to support entity linking of named entities to real-world concepts. and their textual aliases, to support entity linking of named entities to
real-world concepts.
DOCS: https://spacy.io/api/inmemorylookupkb DOCS: https://spacy.io/api/inmemorylookupkb
""" """
@ -71,7 +72,8 @@ cdef class InMemoryLookupKB(KnowledgeBase):
def add_entity(self, str entity, float freq, vector[float] entity_vector): def add_entity(self, str entity, float freq, vector[float] entity_vector):
""" """
Add an entity to the KB, optionally specifying its log probability based on corpus frequency Add an entity to the KB, optionally specifying its log probability
based on corpus frequency.
Return the hash of the entity ID/name at the end. Return the hash of the entity ID/name at the end.
""" """
cdef hash_t entity_hash = self.vocab.strings.add(entity) cdef hash_t entity_hash = self.vocab.strings.add(entity)
@ -83,14 +85,20 @@ cdef class InMemoryLookupKB(KnowledgeBase):
# Raise an error if the provided entity vector is not of the correct length # Raise an error if the provided entity vector is not of the correct length
if len(entity_vector) != self.entity_vector_length: if len(entity_vector) != self.entity_vector_length:
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length)) raise ValueError(
Errors.E141.format(
found=len(entity_vector), required=self.entity_vector_length
)
)
vector_index = self.c_add_vector(entity_vector=entity_vector) vector_index = self.c_add_vector(entity_vector=entity_vector)
new_index = self.c_add_entity(entity_hash=entity_hash, new_index = self.c_add_entity(
freq=freq, entity_hash=entity_hash,
vector_index=vector_index, freq=freq,
feats_row=-1) # Features table currently not implemented vector_index=vector_index,
feats_row=-1
) # Features table currently not implemented
self._entry_index[entity_hash] = new_index self._entry_index[entity_hash] = new_index
return entity_hash return entity_hash
@ -115,7 +123,12 @@ cdef class InMemoryLookupKB(KnowledgeBase):
else: else:
entity_vector = vector_list[i] entity_vector = vector_list[i]
if len(entity_vector) != self.entity_vector_length: if len(entity_vector) != self.entity_vector_length:
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length)) raise ValueError(
Errors.E141.format(
found=len(entity_vector),
required=self.entity_vector_length
)
)
entry.entity_hash = entity_hash entry.entity_hash = entity_hash
entry.freq = freq_list[i] entry.freq = freq_list[i]
@ -149,11 +162,15 @@ cdef class InMemoryLookupKB(KnowledgeBase):
previous_alias_nr = self.get_size_aliases() previous_alias_nr = self.get_size_aliases()
# Throw an error if the length of entities and probabilities are not the same # Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities): if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias, raise ValueError(
entities_length=len(entities), Errors.E132.format(
probabilities_length=len(probabilities))) alias=alias,
entities_length=len(entities),
probabilities_length=len(probabilities))
)
# Throw an error if the probabilities sum up to more than 1 (allow for some rounding errors) # Throw an error if the probabilities sum up to more than 1 (allow for
# some rounding errors)
prob_sum = sum(probabilities) prob_sum = sum(probabilities)
if prob_sum > 1.00001: if prob_sum > 1.00001:
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum)) raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
@ -170,40 +187,47 @@ cdef class InMemoryLookupKB(KnowledgeBase):
for entity, prob in zip(entities, probabilities): for entity, prob in zip(entities, probabilities):
entity_hash = self.vocab.strings[entity] entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index: if entity_hash not in self._entry_index:
raise ValueError(Errors.E134.format(entity=entity)) raise ValueError(Errors.E134.format(entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash) entry_index = <int64_t>self._entry_index.get(entity_hash)
entry_indices.push_back(int(entry_index)) entry_indices.push_back(int(entry_index))
probs.push_back(float(prob)) probs.push_back(float(prob))
new_index = self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs) new_index = self.c_add_aliases(
alias_hash=alias_hash, entry_indices=entry_indices, probs=probs
)
self._alias_index[alias_hash] = new_index self._alias_index[alias_hash] = new_index
if previous_alias_nr + 1 != self.get_size_aliases(): if previous_alias_nr + 1 != self.get_size_aliases():
raise RuntimeError(Errors.E891.format(alias=alias)) raise RuntimeError(Errors.E891.format(alias=alias))
return alias_hash return alias_hash
def append_alias(self, str alias, str entity, float prior_prob, ignore_warnings=False): def append_alias(
self, str alias, str entity, float prior_prob, ignore_warnings=False
):
""" """
For an alias already existing in the KB, extend its potential entities with one more. For an alias already existing in the KB, extend its potential entities
with one more.
Throw a warning if either the alias or the entity is unknown, Throw a warning if either the alias or the entity is unknown,
or when the combination is already previously recorded. or when the combination is already previously recorded.
Throw an error if this entity+prior prob would exceed the sum of 1. Throw an error if this entity+prior prob would exceed the sum of 1.
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one. For efficiency, it's best to use the method `add_alias` as much as
possible instead of this one.
""" """
# Check if the alias exists in the KB # Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index: if alias_hash not in self._alias_index:
raise ValueError(Errors.E176.format(alias=alias)) raise ValueError(Errors.E176.format(alias=alias))
# Check if the entity exists in the KB # Check if the entity exists in the KB
cdef hash_t entity_hash = self.vocab.strings[entity] cdef hash_t entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index: if entity_hash not in self._entry_index:
raise ValueError(Errors.E134.format(entity=entity)) raise ValueError(Errors.E134.format(entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash) entry_index = <int64_t>self._entry_index.get(entity_hash)
# Throw an error if the prior probabilities (including the new one) sum up to more than 1 # Throw an error if the prior probabilities (including the new one)
# sum up to more than 1
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
current_sum = sum([p for p in alias_entry.probs]) current_sum = sum([p for p in alias_entry.probs])
@ -236,12 +260,13 @@ cdef class InMemoryLookupKB(KnowledgeBase):
def get_alias_candidates(self, str alias) -> Iterable[Candidate]: def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
""" """
Return candidate entities for an alias. Each candidate defines the entity, the original alias, Return candidate entities for an alias. Each candidate defines the
and the prior probability of that alias resolving to that entity. entity, the original alias, and the prior probability of that alias
resolving to that entity.
If the alias is not known in the KB, and empty list is returned. If the alias is not known in the KB, and empty list is returned.
""" """
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index: if alias_hash not in self._alias_index:
return [] return []
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
@ -249,10 +274,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
return [Candidate(kb=self, return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash, entity_hash=self._entries[entry_index].entity_hash,
entity_freq=self._entries[entry_index].freq, entity_freq=self._entries[entry_index].freq,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index], entity_vector=self._vectors_table[
self._entries[entry_index].vector_index
],
alias_hash=alias_hash, alias_hash=alias_hash,
prior_prob=prior_prob) prior_prob=prior_prob)
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) for (entry_index, prior_prob) in zip(
alias_entry.entry_indices, alias_entry.probs
)
if entry_index != 0] if entry_index != 0]
def get_vector(self, str entity): def get_vector(self, str entity):
@ -266,8 +295,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
return self._vectors_table[self._entries[entry_index].vector_index] return self._vectors_table[self._entries[entry_index].vector_index]
def get_prior_prob(self, str entity, str alias): def get_prior_prob(self, str entity, str alias):
""" Return the prior probability of a given alias being linked to a given entity, """ Return the prior probability of a given alias being linked to a
or return 0.0 when this combination is not known in the knowledge base""" given entity, or return 0.0 when this combination is not known in the
knowledge base."""
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
cdef hash_t entity_hash = self.vocab.strings[entity] cdef hash_t entity_hash = self.vocab.strings[entity]
@ -278,7 +308,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
entry_index = self._entry_index[entity_hash] entry_index = self._entry_index[entity_hash]
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs): for (entry_index, prior_prob) in zip(
alias_entry.entry_indices, alias_entry.probs
):
if self._entries[entry_index].entity_hash == entity_hash: if self._entries[entry_index].entity_hash == entity_hash:
return prior_prob return prior_prob
@ -288,13 +320,19 @@ cdef class InMemoryLookupKB(KnowledgeBase):
"""Serialize the current state to a binary string. """Serialize the current state to a binary string.
""" """
def serialize_header(): def serialize_header():
header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length) header = (
self.get_size_entities(),
self.get_size_aliases(),
self.entity_vector_length
)
return srsly.json_dumps(header) return srsly.json_dumps(header)
def serialize_entries(): def serialize_entries():
i = 1 i = 1
tuples = [] tuples = []
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): for entry_hash, entry_index in sorted(
self._entry_index.items(), key=lambda x: x[1]
):
entry = self._entries[entry_index] entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash assert entry.entity_hash == entry_hash
assert entry_index == i assert entry_index == i
@ -307,7 +345,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
headers = [] headers = []
indices_lists = [] indices_lists = []
probs_lists = [] probs_lists = []
for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]): for alias_hash, alias_index in sorted(
self._alias_index.items(), key=lambda x: x[1]
):
alias = self._aliases_table[alias_index] alias = self._aliases_table[alias_index]
assert alias_index == i assert alias_index == i
candidate_length = len(alias.entry_indices) candidate_length = len(alias.entry_indices)
@ -365,7 +405,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
indices = srsly.json_loads(all_data[1]) indices = srsly.json_loads(all_data[1])
probs = srsly.json_loads(all_data[2]) probs = srsly.json_loads(all_data[2])
for header, indices, probs in zip(headers, indices, probs): for header, indices, probs in zip(headers, indices, probs):
alias_hash, candidate_length = header alias_hash, _candidate_length = header
alias.entry_indices = indices alias.entry_indices = indices
alias.probs = probs alias.probs = probs
self._aliases_table[i] = alias self._aliases_table[i] = alias
@ -414,10 +454,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
writer.write_vector_element(element) writer.write_vector_element(element)
i = i+1 i = i+1
# dumping the entry records in the order in which they are in the _entries vector. # dumping the entry records in the order in which they are in the
# index 0 is a dummy object not stored in the _entry_index and can be ignored. # _entries vector.
# index 0 is a dummy object not stored in the _entry_index and can
# be ignored.
i = 1 i = 1
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): for entry_hash, entry_index in sorted(
self._entry_index.items(), key=lambda x: x[1]
):
entry = self._entries[entry_index] entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash assert entry.entity_hash == entry_hash
assert entry_index == i assert entry_index == i
@ -429,7 +473,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
# dumping the aliases in the order in which they are in the _alias_index vector. # dumping the aliases in the order in which they are in the _alias_index vector.
# index 0 is a dummy object not stored in the _aliases_table and can be ignored. # index 0 is a dummy object not stored in the _aliases_table and can be ignored.
i = 1 i = 1
for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]): for alias_hash, alias_index in sorted(
self._alias_index.items(), key=lambda x: x[1]
):
alias = self._aliases_table[alias_index] alias = self._aliases_table[alias_index]
assert alias_index == i assert alias_index == i
@ -535,7 +581,8 @@ cdef class Writer:
def __init__(self, path): def __init__(self, path):
assert isinstance(path, Path) assert isinstance(path, Path)
content = bytes(path) content = bytes(path)
cdef bytes bytes_loc = content.encode('utf8') if type(content) == str else content cdef bytes bytes_loc = content.encode('utf8') \
if type(content) == str else content
self._fp = fopen(<char*>bytes_loc, 'wb') self._fp = fopen(<char*>bytes_loc, 'wb')
if not self._fp: if not self._fp:
raise IOError(Errors.E146.format(path=path)) raise IOError(Errors.E146.format(path=path))
@ -545,14 +592,18 @@ cdef class Writer:
cdef size_t status = fclose(self._fp) cdef size_t status = fclose(self._fp)
assert status == 0 assert status == 0
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1: cdef int write_header(
self, int64_t nr_entries, int64_t entity_vector_length
) except -1:
self._write(&nr_entries, sizeof(nr_entries)) self._write(&nr_entries, sizeof(nr_entries))
self._write(&entity_vector_length, sizeof(entity_vector_length)) self._write(&entity_vector_length, sizeof(entity_vector_length))
cdef int write_vector_element(self, float element) except -1: cdef int write_vector_element(self, float element) except -1:
self._write(&element, sizeof(element)) self._write(&element, sizeof(element))
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1: cdef int write_entry(
self, hash_t entry_hash, float entry_freq, int32_t vector_index
) except -1:
self._write(&entry_hash, sizeof(entry_hash)) self._write(&entry_hash, sizeof(entry_hash))
self._write(&entry_freq, sizeof(entry_freq)) self._write(&entry_freq, sizeof(entry_freq))
self._write(&vector_index, sizeof(vector_index)) self._write(&vector_index, sizeof(vector_index))
@ -561,7 +612,9 @@ cdef class Writer:
cdef int write_alias_length(self, int64_t alias_length) except -1: cdef int write_alias_length(self, int64_t alias_length) except -1:
self._write(&alias_length, sizeof(alias_length)) self._write(&alias_length, sizeof(alias_length))
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1: cdef int write_alias_header(
self, hash_t alias_hash, int64_t candidate_length
) except -1:
self._write(&alias_hash, sizeof(alias_hash)) self._write(&alias_hash, sizeof(alias_hash))
self._write(&candidate_length, sizeof(candidate_length)) self._write(&candidate_length, sizeof(candidate_length))
@ -577,16 +630,19 @@ cdef class Writer:
cdef class Reader: cdef class Reader:
def __init__(self, path): def __init__(self, path):
content = bytes(path) content = bytes(path)
cdef bytes bytes_loc = content.encode('utf8') if type(content) == str else content cdef bytes bytes_loc = content.encode('utf8') \
if type(content) == str else content
self._fp = fopen(<char*>bytes_loc, 'rb') self._fp = fopen(<char*>bytes_loc, 'rb')
if not self._fp: if not self._fp:
PyErr_SetFromErrno(IOError) PyErr_SetFromErrno(IOError)
status = fseek(self._fp, 0, 0) # this can be 0 if there is no header fseek(self._fp, 0, 0) # this can be 0 if there is no header
def __dealloc__(self): def __dealloc__(self):
fclose(self._fp) fclose(self._fp)
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1: cdef int read_header(
self, int64_t* nr_entries, int64_t* entity_vector_length
) except -1:
status = self._read(nr_entries, sizeof(int64_t)) status = self._read(nr_entries, sizeof(int64_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
@ -606,7 +662,9 @@ cdef class Reader:
return 0 # end of file return 0 # end of file
raise IOError(Errors.E145.format(param="vector element")) raise IOError(Errors.E145.format(param="vector element"))
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1: cdef int read_entry(
self, hash_t* entity_hash, float* freq, int32_t* vector_index
) except -1:
status = self._read(entity_hash, sizeof(hash_t)) status = self._read(entity_hash, sizeof(hash_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
@ -637,7 +695,9 @@ cdef class Reader:
return 0 # end of file return 0 # end of file
raise IOError(Errors.E145.format(param="alias length")) raise IOError(Errors.E145.format(param="alias length"))
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1: cdef int read_alias_header(
self, hash_t* alias_hash, int64_t* candidate_length
) except -1:
status = self._read(alias_hash, sizeof(hash_t)) status = self._read(alias_hash, sizeof(hash_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):