Port Candidate and InMemoryCandidate to Cython.

This commit is contained in:
Raphael Mitsch 2023-03-09 14:44:41 +01:00
parent 1c937db3af
commit b476041417
6 changed files with 58 additions and 39 deletions

View File

@ -30,6 +30,8 @@ MOD_NAMES = [
"spacy.lexeme", "spacy.lexeme",
"spacy.vocab", "spacy.vocab",
"spacy.attrs", "spacy.attrs",
"spacy.kb.candidate",
# "spacy.kb.inmemorycandidate",
"spacy.kb.kb", "spacy.kb.kb",
"spacy.kb.kb_in_memory", "spacy.kb.kb_in_memory",
"spacy.ml.tb_framework", "spacy.ml.tb_framework",

View File

@ -2,4 +2,5 @@ from .kb import KnowledgeBase
from .kb_in_memory import InMemoryLookupKB from .kb_in_memory import InMemoryLookupKB
from .candidate import Candidate, InMemoryCandidate from .candidate import Candidate, InMemoryCandidate
__all__ = ["KnowledgeBase", "InMemoryLookupKB", "Candidate", "InMemoryCandidate"] __all__ = ["KnowledgeBase", "InMemoryLookupKB", "Candidate", "InMemoryCandidate"]

17
spacy/kb/candidate.pxd Normal file
View File

@ -0,0 +1,17 @@
from libcpp.vector cimport vector
from .kb_in_memory cimport InMemoryLookupKB
from ..typedefs cimport hash_t
cdef class Candidate:
cdef readonly str _entity_id_
cdef readonly hash_t _entity_id
cdef readonly str _mention
cpdef vector[float] _entity_vector
cdef float _prior_prob
cdef class InMemoryCandidate(Candidate):
cdef readonly InMemoryLookupKB _kb
cdef hash_t _entity_hash
cdef float _entity_freq
cdef hash_t _alias_hash

View File

@ -1,10 +1,12 @@
import abc # cython: infer_types=True, profile=True
from typing import List, Union, Callable
from ..errors import Errors from ..typedefs cimport hash_t
from .kb cimport KnowledgeBase
from .kb_in_memory cimport InMemoryLookupKB
class Candidate(abc.ABC): cdef class Candidate:
"""A `Candidate` object refers to a textual mention that may or may not be resolved """A `Candidate` object refers to a textual mention that may or may not be resolved
to a specific `entity_id` from a Knowledge Base. This will be used as input for the entity_id linking to a specific `entity_id` from a Knowledge Base. This will be used as input for the entity_id linking
algorithm which will disambiguate the various candidates to the correct one. algorithm which will disambiguate the various candidates to the correct one.
@ -16,8 +18,8 @@ class Candidate(abc.ABC):
def __init__( def __init__(
self, self,
mention: str, mention: str,
entity_id: Union[str, int], entity_id: str,
entity_vector: List[float], entity_vector: vector[float],
prior_prob: float, prior_prob: float,
): ):
"""Initializes properties of `Candidate` instance. """Initializes properties of `Candidate` instance.
@ -30,22 +32,23 @@ class Candidate(abc.ABC):
doesn't) it might be better to eschew this information and always supply the same value. doesn't) it might be better to eschew this information and always supply the same value.
""" """
self._mention = mention self._mention = mention
self._entity_id = entity_id self._entity_id_ = entity_id
# Note that hashing an int value yields the same int value. # Note that hashing an int value yields the same int value.
self._entity_id_hash = hash(entity_id) self._entity_id = hash(entity_id)
self._entity_vector = entity_vector self._entity_vector = entity_vector
self._prior_prob = prior_prob self._prior_prob = prior_prob
# todo raise exception if this is instantiated class
@property @property
def entity_id(self) -> int: def entity_id(self) -> int:
"""RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID, """RETURNS (int): Numerical representation of entity ID (if entity ID is numerical, this is just the entity ID,
otherwise the hash of the entity ID string).""" otherwise the hash of the entity ID string)."""
return self._entity_id_hash return self._entity_id
@property @property
def entity_id_(self) -> str: def entity_id_(self) -> str:
"""RETURNS (str): String representation of entity ID.""" """RETURNS (str): String representation of entity ID."""
return str(self._entity_id) return self._entity_id_
@property @property
def mention(self) -> str: def mention(self) -> str:
@ -53,8 +56,8 @@ class Candidate(abc.ABC):
return self._mention return self._mention
@property @property
def entity_vector(self) -> List[float]: def entity_vector(self) -> vector[float]:
"""RETURNS (List[float]): Entity vector.""" """RETURNS (vector[float]): Entity vector."""
return self._entity_vector return self._entity_vector
@property @property
@ -63,20 +66,20 @@ class Candidate(abc.ABC):
return self._prior_prob return self._prior_prob
class InMemoryCandidate(Candidate): cdef class InMemoryCandidate(Candidate):
"""Candidate for InMemoryLookupKB.""" """Candidate for InMemoryLookupKB."""
def __init__( def __init__(
self, self,
hash_to_str: Callable[[int], str], kb: InMemoryLookupKB,
entity_id: int, entity_hash: int,
mention: str, mention: str,
entity_vector: List[float], entity_vector: vector[float],
prior_prob: float, prior_prob: float,
entity_freq: int, entity_freq: float
): ):
""" """
hash_to_str (Callable[[int], str]): Callable retrieving entity name from provided entity/vocab hash. kb (InMemoryLookupKB]): InMemoryLookupKB instance.
entity_id (int): Entity ID as hash that can be looked up with InMemoryKB.vocab.strings.__getitem__(). entity_id (int): Entity ID as hash that can be looked up with InMemoryKB.vocab.strings.__getitem__().
entity_freq (int): Entity frequency in KB corpus. entity_freq (int): Entity frequency in KB corpus.
entity_vector (List[float]): Entity embedding. entity_vector (List[float]): Entity embedding.
@ -88,24 +91,19 @@ class InMemoryCandidate(Candidate):
""" """
super().__init__( super().__init__(
mention=mention, mention=mention,
entity_id=entity_id, entity_id=kb.vocab.strings[entity_hash],
entity_vector=entity_vector, entity_vector=entity_vector,
prior_prob=prior_prob, prior_prob=prior_prob,
) )
self._hash_to_str = hash_to_str self._kb = kb
self._entity_id = entity_hash
self._entity_freq = entity_freq self._entity_freq = entity_freq
if not isinstance(self._entity_id, int):
raise ValueError(
Errors.E4006.format(exp_type="int", found_type=str(type(entity_id)))
)
self._entity_id_str = self._hash_to_str(self._entity_id)
@property
def entity_freq(self) -> float:
"""RETURNS (float): Relative entity frequency."""
return self._entity_freq
@property @property
def entity_id_(self) -> str: def entity_id_(self) -> str:
"""RETURNS (str): String representation of entity ID.""" """RETURNS (str): ID/name of this entity in the KB"""
return self._entity_id_str return self._kb.vocab.strings[self._entity_id]
@property
def entity_freq(self) -> float:
return self._entity_freq

View File

@ -243,8 +243,8 @@ cdef class InMemoryLookupKB(KnowledgeBase):
return [ return [
InMemoryCandidate( InMemoryCandidate(
hash_to_str=self.vocab.strings.__getitem__, kb=self,
entity_id=self._entries[entry_index].entity_hash, entity_hash=self._entries[entry_index].entity_hash,
mention=alias, mention=alias,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index], entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
prior_prob=prior_prob, prior_prob=prior_prob,

View File

@ -465,16 +465,17 @@ def test_candidate_generation(nlp):
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates # test the size of the relevant candidates
adam_ent_cands = get_candidates(mykb, adam_ent)
assert len(get_candidates(mykb, douglas_ent)) == 2 assert len(get_candidates(mykb, douglas_ent)) == 2
assert len(get_candidates(mykb, adam_ent)) == 1 assert len(adam_ent_cands) == 1
assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive
assert len(get_candidates(mykb, shrubbery_ent)) == 0 assert len(get_candidates(mykb, shrubbery_ent)) == 0
# test the content of the candidates # test the content of the candidates
assert get_candidates(mykb, adam_ent)[0].entity_id_ == "Q2" assert adam_ent_cands[0].entity_id_ == "Q2"
assert get_candidates(mykb, adam_ent)[0].mention == "adam" assert adam_ent_cands[0].mention == "adam"
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12) assert_almost_equal(adam_ent_cands[0].entity_freq, 12)
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9) assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9)
def test_el_pipe_configuration(nlp): def test_el_pipe_configuration(nlp):