diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py
index f9887a4df..69c7b644d 100644
--- a/spacy/lang/zh/__init__.py
+++ b/spacy/lang/zh/__init__.py
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional, List, Dict, Any
 from enum import Enum
 import tempfile
 import srsly
@@ -137,6 +137,18 @@ class ChineseTokenizer(DummyTokenizer):
             warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
             warnings.warn(warn_msg)
 
+    def _get_config(self) -> Dict[str, Any]:
+        return {
+            "segmenter": self.segmenter,
+            "pkuseg_model": self.pkuseg_model,
+            "pkuseg_user_dict": self.pkuseg_user_dict,
+        }
+
+    def _set_config(self, config: Dict[str, Any] = {}) -> None:
+        self.segmenter = config.get("segmenter", Segmenter.char)
+        self.pkuseg_model = config.get("pkuseg_model", None)
+        self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")
+
     def to_bytes(self, **kwargs):
         pkuseg_features_b = b""
         pkuseg_weights_b = b""
@@ -173,6 +185,7 @@ class ChineseTokenizer(DummyTokenizer):
                 sorted(list(self.pkuseg_seg.postprocesser.other_words)),
             )
         serializers = {
+            "cfg": lambda: srsly.json_dumps(self._get_config()),
             "pkuseg_features": lambda: pkuseg_features_b,
             "pkuseg_weights": lambda: pkuseg_weights_b,
             "pkuseg_processors": lambda: srsly.msgpack_dumps(pkuseg_processors_data),
@@ -192,6 +205,7 @@ class ChineseTokenizer(DummyTokenizer):
             pkuseg_data["processors_data"] = srsly.msgpack_loads(b)
 
         deserializers = {
+            "cfg": lambda b: self._set_config(srsly.json_loads(b)),
             "pkuseg_features": deserialize_pkuseg_features,
             "pkuseg_weights": deserialize_pkuseg_weights,
             "pkuseg_processors": deserialize_pkuseg_processors,
@@ -256,6 +270,7 @@ class ChineseTokenizer(DummyTokenizer):
                 srsly.write_msgpack(path, data)
 
         serializers = {
+            "cfg": lambda p: srsly.write_json(p, self._get_config()),
             "pkuseg_model": lambda p: save_pkuseg_model(p),
             "pkuseg_processors": lambda p: save_pkuseg_processors(p),
         }
@@ -291,6 +306,7 @@ class ChineseTokenizer(DummyTokenizer):
                 self.pkuseg_seg.postprocesser.other_words = set(other_words)
 
         serializers = {
+            "cfg": lambda p: self._set_config(srsly.read_json(p)),
             "pkuseg_model": lambda p: load_pkuseg_model(p),
             "pkuseg_processors": lambda p: load_pkuseg_processors(p),
         }
diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py
index 3a9a1f26b..23fc5e98f 100644
--- a/spacy/tests/conftest.py
+++ b/spacy/tests/conftest.py
@@ -282,6 +282,7 @@ def zh_tokenizer_jieba():
 @pytest.fixture(scope="session")
 def zh_tokenizer_pkuseg():
     pytest.importorskip("pkuseg")
+    pytest.importorskip("pickle5")
     config = {
         "@tokenizers": "spacy.zh.ChineseTokenizer",
         "segmenter": "pkuseg",
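
Not part of the patch, only an illustration of what the new "cfg" entry carries: to_bytes()/to_disk() dump the output of _get_config() as JSON via srsly, and from_bytes()/from_disk() feed the parsed JSON back into _set_config(). A minimal sketch of that round trip, assuming only that srsly is installed and using the default char-segmenter values:

    import srsly

    # What the "cfg" serializer in to_bytes() writes for a default tokenizer
    # (segmenter="char", no pkuseg model, default user dict).
    cfg = {
        "segmenter": "char",
        "pkuseg_model": None,
        "pkuseg_user_dict": "default",
    }
    cfg_json = srsly.json_dumps(cfg)

    # What the matching "cfg" deserializer in from_bytes() hands to _set_config().
    restored = srsly.json_loads(cfg_json)
    assert restored["segmenter"] == "char"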