mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Fix serialization of subclassed KB in tests.
This commit is contained in:
parent
f1c48baeb5
commit
322b0050c4
|
@ -1,7 +1,9 @@
|
|||
from typing import Callable
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Any, Dict
|
||||
|
||||
from spacy import util
|
||||
from spacy.util import ensure_path, registry, load_model_from_config
|
||||
from spacy import util, Errors
|
||||
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
|
||||
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
||||
from spacy.vocab import Vocab
|
||||
from thinc.api import Config
|
||||
|
@ -112,6 +114,43 @@ def test_serialize_subclassed_kb():
|
|||
super().__init__(vocab, entity_vector_length)
|
||||
self.custom_field = custom_field
|
||||
|
||||
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
    """Write this KB, including `self.custom_field`, to the directory `path`.

    Overrides InMemoryLookupKB.to_disk() so the subclass-specific field is
    stored alongside the regular contents and the vocab strings.

    path: Target directory. It is created (with parents) if it does not exist.
    exclude: Forwarded unchanged to util.to_disk().
    RAISES (ValueError): If `path` exists afterwards but is not a directory.
    """
    target_dir = ensure_path(path)
    if not target_dir.exists():
        target_dir.mkdir(parents=True)
    if not target_dir.is_dir():
        raise ValueError(Errors.E928.format(loc=target_dir))

    def dump_custom_fields(file_path: Path) -> None:
        # Collect the extra attributes at write time and pickle them as one dict.
        payload: Dict[str, Any] = {"custom_field": self.custom_field}
        with open(file_path, "wb") as handle:
            pickle.dump(payload, handle)

    # One writer per entry; each receives the location it should write to.
    writers: Dict[str, Callable[[Any], Any]] = {
        "contents": lambda loc: self.write_contents(loc),
        "strings.json": lambda loc: self.vocab.strings.to_disk(loc),
        "custom_fields": lambda loc: dump_custom_fields(loc),
    }
    util.to_disk(target_dir, writers, exclude)
|
||||
|
||||
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
    """Load this KB, including `self.custom_field`, from the directory `path`.

    Reads the same three entries ("contents", "strings.json" and
    "custom_fields") that to_disk() of this class writes.

    path: Directory to read from.
    exclude: Forwarded unchanged to util.from_disk().
    RAISES (ValueError): If `path` does not exist or is not a directory.
    """
    source_dir = ensure_path(path)
    if not source_dir.exists():
        raise ValueError(Errors.E929.format(loc=source_dir))
    if not source_dir.is_dir():
        raise ValueError(Errors.E928.format(loc=source_dir))

    def load_custom_fields(file_path: Path) -> None:
        # Unpickle the stored dict and restore the subclass-specific attribute.
        with open(file_path, "rb") as handle:
            stored = pickle.load(handle)
            self.custom_field = stored["custom_field"]

    # One reader per entry; each receives the location it should read from.
    readers: Dict[str, Callable[[Any], Any]] = {
        "contents": lambda loc: self.read_contents(loc),
        "strings.json": lambda loc: self.vocab.strings.from_disk(loc),
        "custom_fields": lambda loc: load_custom_fields(loc),
    }
    util.from_disk(source_dir, readers, exclude)
|
||||
|
||||
@registry.misc("kb_test.CustomEmptyKB.v1")
|
||||
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
|
||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||
|
@ -126,7 +165,7 @@ def test_serialize_subclassed_kb():
|
|||
@registry.misc("kb_test.CustomKB.v1")
|
||||
def custom_kb(
|
||||
entity_vector_length: int, custom_field: int
|
||||
) -> Callable[[Vocab], InMemoryLookupKB]:
|
||||
) -> Callable[[Vocab], SubInMemoryLookupKB]:
|
||||
def custom_kb_factory(vocab):
|
||||
kb = SubInMemoryLookupKB(
|
||||
vocab=vocab,
|
||||
|
|
Loading…
Reference in New Issue
Block a user