Add generate_from_disk() factory method.

This commit is contained in:
Raphael Mitsch 2022-11-25 12:02:37 +01:00
parent 4eb072fa91
commit b1d458eca7
2 changed files with 24 additions and 1 deletions

View File

@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from pathlib import Path from pathlib import Path
from typing import Iterable, Tuple, Union, Iterator from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from .candidate import Candidate from .candidate import Candidate
@ -107,6 +107,21 @@ cdef class KnowledgeBase:
Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
) )
KBType = TypeVar("KBType", bound=KnowledgeBase)
@classmethod
def generate_from_disk(
cls: Type[KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> KBType:
"""
Factory method for generating KnowledgeBase instance from file.
path (Union[str, Path]): Target file path.
exclude (Iterable[str]): List of components to exclude.
return (KBType): Instance of KnowledgeBase generated from file.
"""
raise NotImplementedError(
Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__)
)
def __len__(self) -> int: def __len__(self) -> int:
"""Returns number of entities in the KnowledgeBase. """Returns number of entities in the KnowledgeBase.
RETURNS (int): Number of entities in the KnowledgeBase. RETURNS (int): Number of entities in the KnowledgeBase.

View File

@ -46,6 +46,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
self._alias_index = PreshMap(nr_aliases + 1) self._alias_index = PreshMap(nr_aliases + 1)
self._aliases_table = alias_vec(nr_aliases + 1) self._aliases_table = alias_vec(nr_aliases + 1)
@classmethod
def generate_from_disk(
cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> "InMemoryLookupKB":
kb = InMemoryLookupKB(vocab=Vocab(strings=["."]), entity_vector_length=1)
kb.from_disk(path)
return kb
def __len__(self): def __len__(self):
return self.get_size_entities() return self.get_size_entities()