diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index e374cf94d..9317629e2 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,7 +1,7 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union +from typing import Iterable, Tuple, Union, TypeVar, Type from cymem.cymem cimport Pool from .candidate import Candidate @@ -18,6 +18,8 @@ cdef class KnowledgeBase: DOCS: https://spacy.io/api/kb """ + _KBType = TypeVar("_KBType", bound=KnowledgeBase) + def __init__(self, vocab: Vocab, entity_vector_length: int): """Create a KnowledgeBase.""" # Make sure abstract KB is not instantiated. @@ -106,3 +108,25 @@ cdef class KnowledgeBase: raise NotImplementedError( Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) ) + + @classmethod + def generate_from_disk( + cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> _KBType: + """ + Factory method for generating KnowledgeBase subclass instance from file. + path (Union[str, Path]): Target file path. + exclude (Iterable[str]): List of components to exclude. + return (_KBType): Instance of KnowledgeBase subclass generated from file. + """ + raise NotImplementedError( + Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__) + ) + + def __len__(self) -> int: + """Returns number of entities in the KnowledgeBase. + RETURNS (int): Number of entities in the KnowledgeBase. + """ + raise NotImplementedError( + Errors.E1044.format(parent="KnowledgeBase", method="__len__", name=self.__name__) + ) diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 0b8e3f2f4..df4c1d4f3 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -46,6 +46,14 @@ cdef class InMemoryLookupKB(KnowledgeBase): self._alias_index = PreshMap(nr_aliases + 1) self._aliases_table = alias_vec(nr_aliases + 1) + @classmethod + def generate_from_disk( + cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> "InMemoryLookupKB": + kb = InMemoryLookupKB(vocab=Vocab(strings=["."]), entity_vector_length=1) + kb.from_disk(path) + return kb + def __len__(self): return self.get_size_entities()