mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-02 19:30:19 +03:00
Switch from pickle to json for custom field serialization.
This commit is contained in:
parent
09b302507c
commit
3a58c19cc1
|
@ -1,7 +1,8 @@
|
||||||
import pickle
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Iterable, Any, Dict
|
from typing import Callable, Iterable, Any, Dict
|
||||||
|
|
||||||
|
import srsly
|
||||||
|
|
||||||
from spacy import util, Errors
|
from spacy import util, Errors
|
||||||
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
|
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
|
||||||
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
||||||
|
@ -122,18 +123,13 @@ def test_serialize_subclassed_kb():
|
||||||
if not path.is_dir():
|
if not path.is_dir():
|
||||||
raise ValueError(Errors.E928.format(loc=path))
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
|
||||||
def serialize_custom_fields(
|
def serialize_custom_fields(file_path: Path) -> None:
|
||||||
values: Dict[str, Any], file_path: Path
|
srsly.write_json(file_path, {"custom_field": self.custom_field})
|
||||||
) -> None:
|
|
||||||
with open(file_path, "wb") as file:
|
|
||||||
pickle.dump(values, file)
|
|
||||||
|
|
||||||
serialize = {
|
serialize = {
|
||||||
"contents": lambda p: self.write_contents(p),
|
"contents": lambda p: self.write_contents(p),
|
||||||
"strings.json": lambda p: self.vocab.strings.to_disk(p),
|
"strings.json": lambda p: self.vocab.strings.to_disk(p),
|
||||||
"custom_fields": lambda p: serialize_custom_fields(
|
"custom_fields": lambda p: serialize_custom_fields(p),
|
||||||
{"custom_field": self.custom_field}, p
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@ -146,8 +142,7 @@ def test_serialize_subclassed_kb():
|
||||||
raise ValueError(Errors.E928.format(loc=path))
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
|
||||||
def deserialize_custom_fields(file_path: Path) -> None:
|
def deserialize_custom_fields(file_path: Path) -> None:
|
||||||
with open(file_path, "rb") as file:
|
self.custom_field = srsly.read_json(file_path)["custom_field"]
|
||||||
self.custom_field = pickle.load(file)["custom_field"]
|
|
||||||
|
|
||||||
deserialize: Dict[str, Callable[[Any], Any]] = {
|
deserialize: Dict[str, Callable[[Any], Any]] = {
|
||||||
"contents": lambda p: self.read_contents(p),
|
"contents": lambda p: self.read_contents(p),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user