Switch from pickle to json for custom field serialization.

This commit is contained in:
Raphael Mitsch 2023-03-01 11:55:20 +01:00
parent 09b302507c
commit 3a58c19cc1

View File

@ -1,7 +1,8 @@
import pickle
from pathlib import Path
from typing import Callable, Iterable, Any, Dict
import srsly
from spacy import util, Errors
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
from spacy.kb.kb_in_memory import InMemoryLookupKB
@ -122,18 +123,13 @@ def test_serialize_subclassed_kb():
if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
def serialize_custom_fields(
values: Dict[str, Any], file_path: Path
) -> None:
with open(file_path, "wb") as file:
pickle.dump(values, file)
def serialize_custom_fields(file_path: Path) -> None:
srsly.write_json(file_path, {"custom_field": self.custom_field})
serialize = {
"contents": lambda p: self.write_contents(p),
"strings.json": lambda p: self.vocab.strings.to_disk(p),
"custom_fields": lambda p: serialize_custom_fields(
{"custom_field": self.custom_field}, p
),
"custom_fields": lambda p: serialize_custom_fields(p),
}
util.to_disk(path, serialize, exclude)
@ -146,8 +142,7 @@ def test_serialize_subclassed_kb():
raise ValueError(Errors.E928.format(loc=path))
def deserialize_custom_fields(file_path: Path) -> None:
with open(file_path, "rb") as file:
self.custom_field = pickle.load(file)["custom_field"]
self.custom_field = srsly.read_json(file_path)["custom_field"]
deserialize: Dict[str, Callable[[Any], Any]] = {
"contents": lambda p: self.read_contents(p),