Add auxiliary methods for KB (and InMemoryLookupKB).

This commit is contained in:
Raphael Mitsch 2023-02-28 17:29:01 +01:00
parent 50b34751eb
commit 68f6015e2e
2 changed files with 33 additions and 1 deletions

View File

@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True
from pathlib import Path
from typing import Iterable, Tuple, Union
from typing import Iterable, Tuple, Union, TypeVar, Type
from cymem.cymem cimport Pool
from .candidate import Candidate
@ -18,6 +18,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb
"""
# Type variable for factory classmethods that return an instance of the class they are called on.
# The bound is given as a forward-reference string so that evaluating this line does not require
# the name `KnowledgeBase` to be bound already.
_KBType = TypeVar("_KBType", bound="KnowledgeBase")
def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase."""
# Make sure abstract KB is not instantiated.
@ -106,3 +108,25 @@ cdef class KnowledgeBase:
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
)
@classmethod
def generate_from_disk(
    cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> _KBType:
    """
    Factory method for generating a KnowledgeBase subclass instance from file.
    This base implementation always raises; subclasses have to provide their own.
    path (Union[str, Path]): Target file path.
    exclude (Iterable[str]): List of components to exclude.
    RETURNS (_KBType): Instance of KnowledgeBase subclass generated from file.
    RAISES (NotImplementedError): Always, when called on a class that does not override it.
    """
    # E1045 is the message that takes the parent/method/name fields (the same one `from_disk`
    # raises in this class); formatting a different message with these fields fails.
    raise NotImplementedError(
        Errors.E1045.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__)
    )
def __len__(self) -> int:
    """Returns number of entities in the KnowledgeBase.
    This base implementation always raises; subclasses have to provide their own.
    RETURNS (int): Number of entities in the KnowledgeBase.
    RAISES (NotImplementedError): Always, when the subclass does not override it.
    """
    # `__name__` lives on the class, not on the instance, so go through type(self).
    # E1045 is the message that takes the parent/method/name fields.
    raise NotImplementedError(
        Errors.E1045.format(parent="KnowledgeBase", method="__len__", name=type(self).__name__)
    )

View File

@ -46,6 +46,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
self._alias_index = PreshMap(nr_aliases + 1)
self._aliases_table = alias_vec(nr_aliases + 1)
@classmethod
def generate_from_disk(
    cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> "InMemoryLookupKB":
    """
    Factory method for generating an InMemoryLookupKB (or subclass) instance from file.
    path (Union[str, Path]): Target file path.
    exclude (Iterable[str]): List of components to exclude; forwarded to `from_disk`.
    RETURNS (InMemoryLookupKB): Instance created from the data at `path`.
    """
    # Construct with a minimal placeholder vocab and vector length, then load from disk.
    # NOTE(review): presumably `from_disk` replaces these placeholders with the stored
    # values - confirm against `from_disk`.
    # Instantiate via `cls` so that subclasses get an instance of their own type back.
    kb = cls(vocab=Vocab(strings=["."]), entity_vector_length=1)
    kb.from_disk(path, exclude=exclude)
    return kb
def __len__(self):
    """Returns number of entities in the knowledge base.
    RETURNS (int): Whatever `get_size_entities()` reports for this instance.
    """
    n_entities = self.get_size_entities()
    return n_entities