From b1d458eca743cb0ce5218deeec451a39e45cff73 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 25 Nov 2022 12:02:37 +0100 Subject: [PATCH] Add generate_from_disk() factory method. --- spacy/kb/kb.pyx | 17 ++++++++++++++++- spacy/kb/kb_in_memory.pyx | 8 ++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index fee074f0c..3ee434ab5 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,7 +1,7 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union, Iterator +from typing import Iterable, Tuple, Union, Iterator, TypeVar, Type from cymem.cymem cimport Pool from .candidate import Candidate @@ -107,6 +107,21 @@ cdef class KnowledgeBase: Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) ) + KBType = TypeVar("KBType", bound=KnowledgeBase) + @classmethod + def generate_from_disk( + cls: Type[KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> KBType: + """ + Factory method for generating KnowledgeBase instance from file. + path (Union[str, Path]): Target file path. + exclude (Iterable[str]): List of components to exclude. + return (KBType): Instance of KnowledgeBase generated from file. + """ + raise NotImplementedError( + Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__) + ) + def __len__(self) -> int: """Returns number of entities in the KnowledgeBase. RETURNS (int): Number of entities in the KnowledgeBase. diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index 485e52c2f..c030b5f8e 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -46,6 +46,14 @@ cdef class InMemoryLookupKB(KnowledgeBase): self._alias_index = PreshMap(nr_aliases + 1) self._aliases_table = alias_vec(nr_aliases + 1) + @classmethod + def generate_from_disk( + cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() + ) -> "InMemoryLookupKB": + kb = InMemoryLookupKB(vocab=Vocab(strings=["."]), entity_vector_length=1) + kb.from_disk(path) + return kb + def __len__(self): return self.get_size_entities()