diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py index 90fc3136f..102abb4dc 100644 --- a/spacy/tests/serialize/test_serialize_kb.py +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -1,7 +1,9 @@ -from typing import Callable +import pickle +from pathlib import Path +from typing import Callable, Iterable, Any, Dict -from spacy import util -from spacy.util import ensure_path, registry, load_model_from_config +from spacy import util, Errors +from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList from spacy.kb.kb_in_memory import InMemoryLookupKB from spacy.vocab import Vocab from thinc.api import Config @@ -112,6 +114,43 @@ def test_serialize_subclassed_kb(): super().__init__(vocab, entity_vector_length) self.custom_field = custom_field + def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): + """We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well.""" + path = ensure_path(path) + if not path.exists(): + path.mkdir(parents=True) + if not path.is_dir(): + raise ValueError(Errors.E928.format(loc=path)) + + def serialize_custom_fields(values: Dict[str, Any], file_path: Path) -> None: + with open(file_path, "wb") as file: + pickle.dump(values, file) + + serialize = { + "contents": lambda p: self.write_contents(p), + "strings.json": lambda p: self.vocab.strings.to_disk(p), + "custom_fields": lambda p: serialize_custom_fields({"custom_field": self.custom_field}, p) + } + util.to_disk(path, serialize, exclude) + + def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()): + path = ensure_path(path) + if not path.exists(): + raise ValueError(Errors.E929.format(loc=path)) + if not path.is_dir(): + raise ValueError(Errors.E928.format(loc=path)) + + def deserialize_custom_fields(file_path: Path) -> None: + with open(file_path, "rb") as file: + self.custom_field = pickle.load(file)["custom_field"] + + deserialize: Dict[str, Callable[[Any], Any]] = { + "contents": lambda p: self.read_contents(p), + "strings.json": lambda p: self.vocab.strings.from_disk(p), + "custom_fields": lambda p: deserialize_custom_fields(p) + } + util.from_disk(path, deserialize, exclude) + @registry.misc("kb_test.CustomEmptyKB.v1") def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]: def empty_kb_factory(vocab: Vocab, entity_vector_length: int): @@ -126,7 +165,7 @@ def test_serialize_subclassed_kb(): @registry.misc("kb_test.CustomKB.v1") def custom_kb( entity_vector_length: int, custom_field: int - ) -> Callable[[Vocab], InMemoryLookupKB]: + ) -> Callable[[Vocab], SubInMemoryLookupKB]: def custom_kb_factory(vocab): kb = SubInMemoryLookupKB( vocab=vocab,