Switch from pickle to json for custom field serialization.

This commit is contained in:
Raphael Mitsch 2023-03-01 11:55:20 +01:00
parent 09b302507c
commit 3a58c19cc1

View File

@ -1,7 +1,8 @@
import pickle
from pathlib import Path from pathlib import Path
from typing import Callable, Iterable, Any, Dict from typing import Callable, Iterable, Any, Dict
import srsly
from spacy import util, Errors from spacy import util, Errors
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
from spacy.kb.kb_in_memory import InMemoryLookupKB from spacy.kb.kb_in_memory import InMemoryLookupKB
@ -122,18 +123,13 @@ def test_serialize_subclassed_kb():
if not path.is_dir(): if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path)) raise ValueError(Errors.E928.format(loc=path))
def serialize_custom_fields( def serialize_custom_fields(file_path: Path) -> None:
values: Dict[str, Any], file_path: Path srsly.write_json(file_path, {"custom_field": self.custom_field})
) -> None:
with open(file_path, "wb") as file:
pickle.dump(values, file)
serialize = { serialize = {
"contents": lambda p: self.write_contents(p), "contents": lambda p: self.write_contents(p),
"strings.json": lambda p: self.vocab.strings.to_disk(p), "strings.json": lambda p: self.vocab.strings.to_disk(p),
"custom_fields": lambda p: serialize_custom_fields( "custom_fields": lambda p: serialize_custom_fields(p),
{"custom_field": self.custom_field}, p
),
} }
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
@ -146,8 +142,7 @@ def test_serialize_subclassed_kb():
raise ValueError(Errors.E928.format(loc=path)) raise ValueError(Errors.E928.format(loc=path))
def deserialize_custom_fields(file_path: Path) -> None: def deserialize_custom_fields(file_path: Path) -> None:
with open(file_path, "rb") as file: self.custom_field = srsly.read_json(file_path)["custom_field"]
self.custom_field = pickle.load(file)["custom_field"]
deserialize: Dict[str, Callable[[Any], Any]] = { deserialize: Dict[str, Callable[[Any], Any]] = {
"contents": lambda p: self.read_contents(p), "contents": lambda p: self.read_contents(p),