mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 08:42:28 +03:00
Fix serialization of subclassed KB in tests.
This commit is contained in:
parent
f1c48baeb5
commit
322b0050c4
|
@ -1,7 +1,9 @@
|
||||||
from typing import Callable
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Iterable, Any, Dict
|
||||||
|
|
||||||
from spacy import util
|
from spacy import util, Errors
|
||||||
from spacy.util import ensure_path, registry, load_model_from_config
|
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
|
||||||
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
|
@ -112,6 +114,43 @@ def test_serialize_subclassed_kb():
|
||||||
super().__init__(vocab, entity_vector_length)
|
super().__init__(vocab, entity_vector_length)
|
||||||
self.custom_field = custom_field
|
self.custom_field = custom_field
|
||||||
|
|
||||||
|
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||||
|
"""We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well."""
|
||||||
|
path = ensure_path(path)
|
||||||
|
if not path.exists():
|
||||||
|
path.mkdir(parents=True)
|
||||||
|
if not path.is_dir():
|
||||||
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
|
||||||
|
def serialize_custom_fields(values: Dict[str, Any], file_path: Path) -> None:
|
||||||
|
with open(file_path, "wb") as file:
|
||||||
|
pickle.dump(values, file)
|
||||||
|
|
||||||
|
serialize = {
|
||||||
|
"contents": lambda p: self.write_contents(p),
|
||||||
|
"strings.json": lambda p: self.vocab.strings.to_disk(p),
|
||||||
|
"custom_fields": lambda p: serialize_custom_fields({"custom_field": self.custom_field}, p)
|
||||||
|
}
|
||||||
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||||
|
path = ensure_path(path)
|
||||||
|
if not path.exists():
|
||||||
|
raise ValueError(Errors.E929.format(loc=path))
|
||||||
|
if not path.is_dir():
|
||||||
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
|
||||||
|
def deserialize_custom_fields(file_path: Path) -> None:
|
||||||
|
with open(file_path, "rb") as file:
|
||||||
|
self.custom_field = pickle.load(file)["custom_field"]
|
||||||
|
|
||||||
|
deserialize: Dict[str, Callable[[Any], Any]] = {
|
||||||
|
"contents": lambda p: self.read_contents(p),
|
||||||
|
"strings.json": lambda p: self.vocab.strings.from_disk(p),
|
||||||
|
"custom_fields": lambda p: deserialize_custom_fields(p)
|
||||||
|
}
|
||||||
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
|
||||||
@registry.misc("kb_test.CustomEmptyKB.v1")
|
@registry.misc("kb_test.CustomEmptyKB.v1")
|
||||||
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
|
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
|
||||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||||
|
@ -126,7 +165,7 @@ def test_serialize_subclassed_kb():
|
||||||
@registry.misc("kb_test.CustomKB.v1")
|
@registry.misc("kb_test.CustomKB.v1")
|
||||||
def custom_kb(
|
def custom_kb(
|
||||||
entity_vector_length: int, custom_field: int
|
entity_vector_length: int, custom_field: int
|
||||||
) -> Callable[[Vocab], InMemoryLookupKB]:
|
) -> Callable[[Vocab], SubInMemoryLookupKB]:
|
||||||
def custom_kb_factory(vocab):
|
def custom_kb_factory(vocab):
|
||||||
kb = SubInMemoryLookupKB(
|
kb = SubInMemoryLookupKB(
|
||||||
vocab=vocab,
|
vocab=vocab,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user