Minor fixes

* Put `cfg` back in serialization
* Add `pickle5` to pytest conf
This commit is contained in:
Adriane Boyd 2020-09-27 15:15:53 +02:00
parent 54fe871935
commit 8393dbedad
2 changed files with 18 additions and 1 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional, List from typing import Optional, List, Dict, Any
from enum import Enum from enum import Enum
import tempfile import tempfile
import srsly import srsly
@ -137,6 +137,18 @@ class ChineseTokenizer(DummyTokenizer):
warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter) warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
warnings.warn(warn_msg) warnings.warn(warn_msg)
def _get_config(self) -> Dict[str, Any]:
return {
"segmenter": self.segmenter,
"pkuseg_model": self.pkuseg_model,
"pkuseg_user_dict": self.pkuseg_user_dict,
}
def _set_config(self, config: Dict[str, Any] = {}) -> None:
self.segmenter = config.get("segmenter", Segmenter.char)
self.pkuseg_model = config.get("pkuseg_model", None)
self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")
def to_bytes(self, **kwargs): def to_bytes(self, **kwargs):
pkuseg_features_b = b"" pkuseg_features_b = b""
pkuseg_weights_b = b"" pkuseg_weights_b = b""
@ -173,6 +185,7 @@ class ChineseTokenizer(DummyTokenizer):
sorted(list(self.pkuseg_seg.postprocesser.other_words)), sorted(list(self.pkuseg_seg.postprocesser.other_words)),
) )
serializers = { serializers = {
"cfg": lambda: srsly.json_dumps(self._get_config()),
"pkuseg_features": lambda: pkuseg_features_b, "pkuseg_features": lambda: pkuseg_features_b,
"pkuseg_weights": lambda: pkuseg_weights_b, "pkuseg_weights": lambda: pkuseg_weights_b,
"pkuseg_processors": lambda: srsly.msgpack_dumps(pkuseg_processors_data), "pkuseg_processors": lambda: srsly.msgpack_dumps(pkuseg_processors_data),
@ -192,6 +205,7 @@ class ChineseTokenizer(DummyTokenizer):
pkuseg_data["processors_data"] = srsly.msgpack_loads(b) pkuseg_data["processors_data"] = srsly.msgpack_loads(b)
deserializers = { deserializers = {
"cfg": lambda b: self._set_config(srsly.json_loads(b)),
"pkuseg_features": deserialize_pkuseg_features, "pkuseg_features": deserialize_pkuseg_features,
"pkuseg_weights": deserialize_pkuseg_weights, "pkuseg_weights": deserialize_pkuseg_weights,
"pkuseg_processors": deserialize_pkuseg_processors, "pkuseg_processors": deserialize_pkuseg_processors,
@ -256,6 +270,7 @@ class ChineseTokenizer(DummyTokenizer):
srsly.write_msgpack(path, data) srsly.write_msgpack(path, data)
serializers = { serializers = {
"cfg": lambda p: srsly.write_json(p, self._get_config()),
"pkuseg_model": lambda p: save_pkuseg_model(p), "pkuseg_model": lambda p: save_pkuseg_model(p),
"pkuseg_processors": lambda p: save_pkuseg_processors(p), "pkuseg_processors": lambda p: save_pkuseg_processors(p),
} }
@ -291,6 +306,7 @@ class ChineseTokenizer(DummyTokenizer):
self.pkuseg_seg.postprocesser.other_words = set(other_words) self.pkuseg_seg.postprocesser.other_words = set(other_words)
serializers = { serializers = {
"cfg": lambda p: self._set_config(srsly.read_json(p)),
"pkuseg_model": lambda p: load_pkuseg_model(p), "pkuseg_model": lambda p: load_pkuseg_model(p),
"pkuseg_processors": lambda p: load_pkuseg_processors(p), "pkuseg_processors": lambda p: load_pkuseg_processors(p),
} }

View File

@ -282,6 +282,7 @@ def zh_tokenizer_jieba():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def zh_tokenizer_pkuseg(): def zh_tokenizer_pkuseg():
pytest.importorskip("pkuseg") pytest.importorskip("pkuseg")
pytest.importorskip("pickle5")
config = { config = {
"@tokenizers": "spacy.zh.ChineseTokenizer", "@tokenizers": "spacy.zh.ChineseTokenizer",
"segmenter": "pkuseg", "segmenter": "pkuseg",