Add auxiliary methods for KB (and InMemoryLookupKB).

This commit is contained in:
Raphael Mitsch 2023-02-28 17:29:01 +01:00
parent 50b34751eb
commit 68f6015e2e
2 changed files with 33 additions and 1 deletions

View File

@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from pathlib import Path from pathlib import Path
from typing import Iterable, Tuple, Union from typing import Iterable, Tuple, Union, TypeVar, Type
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from .candidate import Candidate from .candidate import Candidate
@ -18,6 +18,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb DOCS: https://spacy.io/api/kb
""" """
_KBType = TypeVar("_KBType", bound=KnowledgeBase)
def __init__(self, vocab: Vocab, entity_vector_length: int): def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase.""" """Create a KnowledgeBase."""
# Make sure abstract KB is not instantiated. # Make sure abstract KB is not instantiated.
@ -106,3 +108,25 @@ cdef class KnowledgeBase:
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
) )
@classmethod
def generate_from_disk(
cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> _KBType:
"""
Factory method for generating KnowledgeBase subclass instance from file.
path (Union[str, Path]): Target file path.
exclude (Iterable[str]): List of components to exclude.
return (_KBType): Instance of KnowledgeBase subclass generated from file.
"""
raise NotImplementedError(
Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__)
)
def __len__(self) -> int:
"""Returns number of entities in the KnowledgeBase.
RETURNS (int): Number of entities in the KnowledgeBase.
"""
raise NotImplementedError(
Errors.E1044.format(parent="KnowledgeBase", method="__len__", name=self.__name__)
)

View File

@ -46,6 +46,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
self._alias_index = PreshMap(nr_aliases + 1) self._alias_index = PreshMap(nr_aliases + 1)
self._aliases_table = alias_vec(nr_aliases + 1) self._aliases_table = alias_vec(nr_aliases + 1)
@classmethod
def generate_from_disk(
cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> "InMemoryLookupKB":
kb = InMemoryLookupKB(vocab=Vocab(strings=["."]), entity_vector_length=1)
kb.from_disk(path)
return kb
def __len__(self): def __len__(self):
return self.get_size_entities() return self.get_size_entities()