Fix serialization of subclassed KB in tests.

This commit is contained in:
Raphael Mitsch 2023-03-01 11:29:23 +01:00
parent f1c48baeb5
commit 322b0050c4

View File

@ -1,7 +1,9 @@
from typing import Callable
import pickle
from pathlib import Path
from typing import Callable, Iterable, Any, Dict
from spacy import util
from spacy.util import ensure_path, registry, load_model_from_config
from spacy import util, Errors
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
from spacy.kb.kb_in_memory import InMemoryLookupKB
from spacy.vocab import Vocab
from thinc.api import Config
@ -112,6 +114,43 @@ def test_serialize_subclassed_kb():
super().__init__(vocab, entity_vector_length)
self.custom_field = custom_field
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
    """Write the KB to the directory `path`, including `self.custom_field`.

    Overrides InMemoryLookupKB.to_disk() so that the custom field is written
    next to the KB contents and the vocab strings.

    path: Target directory; created (including parents) if it does not exist.
    exclude: Forwarded unchanged to util.to_disk().
    RAISES: ValueError (Errors.E928) if `path` exists but is not a directory.
    """
    target_dir = ensure_path(path)
    if not target_dir.exists():
        target_dir.mkdir(parents=True)
    if not target_dir.is_dir():
        raise ValueError(Errors.E928.format(loc=target_dir))

    def dump_custom_fields(file_path: Path) -> None:
        # Collect the extra attributes first, then pickle them to their own file.
        payload: Dict[str, Any] = {"custom_field": self.custom_field}
        with open(file_path, "wb") as out_file:
            pickle.dump(payload, out_file)

    # One writer per entry below `target_dir`; util.to_disk() drives them.
    writers = {
        "contents": lambda dest: self.write_contents(dest),
        "strings.json": lambda dest: self.vocab.strings.to_disk(dest),
        "custom_fields": dump_custom_fields,
    }
    util.to_disk(target_dir, writers, exclude)
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
    """Load the KB from the directory `path`, restoring `self.custom_field`.

    Reads the KB contents, the vocab strings and the pickled custom field
    from their respective entries below `path`.

    path: Directory to read from.
    exclude: Forwarded unchanged to util.from_disk().
    RAISES: ValueError (Errors.E929) if `path` does not exist, or
        ValueError (Errors.E928) if it exists but is not a directory.
    """
    source_dir = ensure_path(path)
    if not source_dir.exists():
        raise ValueError(Errors.E929.format(loc=source_dir))
    if not source_dir.is_dir():
        raise ValueError(Errors.E928.format(loc=source_dir))

    def load_custom_fields(file_path: Path) -> None:
        # NOTE: pickle.load() must only be used on trusted files; this reads
        # back the file written by the "custom_fields" entry.
        with open(file_path, "rb") as in_file:
            stored = pickle.load(in_file)
            self.custom_field = stored["custom_field"]

    # One reader per entry below `source_dir`; util.from_disk() drives them.
    readers: Dict[str, Callable[[Any], Any]] = {
        "contents": lambda src: self.read_contents(src),
        "strings.json": lambda src: self.vocab.strings.from_disk(src),
        "custom_fields": load_custom_fields,
    }
    util.from_disk(source_dir, readers, exclude)
@registry.misc("kb_test.CustomEmptyKB.v1")
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
@ -126,7 +165,7 @@ def test_serialize_subclassed_kb():
@registry.misc("kb_test.CustomKB.v1")
def custom_kb(
entity_vector_length: int, custom_field: int
) -> Callable[[Vocab], InMemoryLookupKB]:
) -> Callable[[Vocab], SubInMemoryLookupKB]:
def custom_kb_factory(vocab):
kb = SubInMemoryLookupKB(
vocab=vocab,