diff --git a/spacy/errors.py b/spacy/errors.py
index 09b722a7b..f8fb7dd8b 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -672,14 +672,22 @@ class Errors:
     E999 = ("Unable to merge the `Doc` objects because they do not all share "
             "the same `Vocab`.")
     E1000 = ("The Chinese word segmenter is pkuseg but no pkuseg model was "
-             "specified. Provide the name of a pretrained model or the path to "
-             "a model when initializing the pipeline:\n"
+             "loaded. Provide the name of a pretrained model or the path to "
+             "a model and initialize the pipeline:\n\n"
              'config = {\n'
-             '    "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
-             '    "segmenter": "pkuseg",\n'
-             '    "pkuseg_model": "default", # or "/path/to/pkuseg_model" \n'
+             '    "nlp": {\n'
+             '        "tokenizer": {\n'
+             '            "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
+             '            "segmenter": "pkuseg",\n'
+             '        }\n'
+             '    },\n'
+             '    "initialize": {"tokenizer": {\n'
+             '        "pkuseg_model": "default", # or /path/to/model\n'
+             '        }\n'
+             '    },\n'
              '}\n'
-             'nlp = Chinese.from_config({"nlp": {"tokenizer": config}})')
+             'nlp = Chinese.from_config(config)\n'
+             'nlp.initialize()')
     E1001 = ("Target token outside of matched span for match with tokens "
              "'{span}' and offset '{index}' matched by patterns '{patterns}'.")
     E1002 = ("Span index out of range.")
diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py
index a413d86eb..ecabb6555 100644
--- a/spacy/lang/zh/__init__.py
+++ b/spacy/lang/zh/__init__.py
@@ -59,32 +59,13 @@ class ChineseTokenizer(DummyTokenizer):
         self,
         nlp: Language,
         segmenter: Segmenter = Segmenter.char,
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None,
     ):
         self.vocab = nlp.vocab
         if isinstance(segmenter, Segmenter):
             segmenter = segmenter.value
         self.segmenter = segmenter
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
         self.pkuseg_seg = None
         self.jieba_seg = None
-        self.configure_segmenter(segmenter)
-
-    def initialize(
-        self,
-        get_examples: Callable[[], Iterable[Example]],
-        *,
-        nlp: Optional[Language],
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None
-    ):
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
-        self.configure_segmenter(self.segmenter)
-
-    def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
                 lang="Chinese",
@@ -94,12 +75,21 @@ class ChineseTokenizer(DummyTokenizer):
             )
             warnings.warn(warn_msg)
             self.segmenter = Segmenter.char
-        self.jieba_seg = try_jieba_import(self.segmenter)
-        self.pkuseg_seg = try_pkuseg_import(
-            self.segmenter,
-            pkuseg_model=self.pkuseg_model,
-            pkuseg_user_dict=self.pkuseg_user_dict,
-        )
+        if segmenter == Segmenter.jieba:
+            self.jieba_seg = try_jieba_import()
+
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: str = "default",
+    ):
+        if self.segmenter == Segmenter.pkuseg:
+            self.pkuseg_seg = try_pkuseg_import(
+                pkuseg_model=pkuseg_model, pkuseg_user_dict=pkuseg_user_dict,
+            )

     def __call__(self, text: str) -> Doc:
         if self.segmenter == Segmenter.jieba:
@@ -154,14 +144,10 @@ class ChineseTokenizer(DummyTokenizer):
     def _get_config(self) -> Dict[str, Any]:
         return {
             "segmenter": self.segmenter,
-            "pkuseg_model": self.pkuseg_model,
-            "pkuseg_user_dict": self.pkuseg_user_dict,
         }

     def _set_config(self, config: Dict[str, Any] = {}) -> None:
         self.segmenter = config.get("segmenter", Segmenter.char)
-        self.pkuseg_model = config.get("pkuseg_model", None)
-        self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")

     def to_bytes(self, **kwargs):
         pkuseg_features_b = b""
@@ -339,42 +325,33 @@ class Chinese(Language):
     Defaults = ChineseDefaults


-def try_jieba_import(segmenter: str) -> None:
+def try_jieba_import() -> None:
     try:
         import jieba

-        if segmenter == Segmenter.jieba:
-            # segment a short text to have jieba initialize its cache in advance
-            list(jieba.cut("作为", cut_all=False))
+        # segment a short text to have jieba initialize its cache in advance
+        list(jieba.cut("作为", cut_all=False))

         return jieba
     except ImportError:
-        if segmenter == Segmenter.jieba:
-            msg = (
-                "Jieba not installed. To use jieba, install it with `pip "
-                " install jieba` or from https://github.com/fxsjy/jieba"
-            )
-            raise ImportError(msg) from None
+        msg = (
+            "Jieba not installed. To use jieba, install it with `pip "
+            " install jieba` or from https://github.com/fxsjy/jieba"
+        )
+        raise ImportError(msg) from None


-def try_pkuseg_import(
-    segmenter: str, pkuseg_model: Optional[str], pkuseg_user_dict: str
-) -> None:
+def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
     try:
         import pkuseg

-        if pkuseg_model is None:
-            return None
-        else:
-            return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
+        return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
     except ImportError:
-        if segmenter == Segmenter.pkuseg:
-            msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
-            raise ImportError(msg) from None
+        msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
+        raise ImportError(msg) from None
     except FileNotFoundError:
-        if segmenter == Segmenter.pkuseg:
-            msg = "Unable to load pkuseg model from: " + pkuseg_model
-            raise FileNotFoundError(msg) from None
+        msg = "Unable to load pkuseg model from: " + pkuseg_model
+        raise FileNotFoundError(msg) from None


 def _get_pkuseg_trie_data(node, path=""):
diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py
index 6cf019173..bcf582388 100644
--- a/spacy/tests/conftest.py
+++ b/spacy/tests/conftest.py
@@ -272,10 +272,14 @@ def zh_tokenizer_char():
 def zh_tokenizer_jieba():
     pytest.importorskip("jieba")
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "jieba",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "jieba",
+            }
+        }
     }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+    nlp = get_lang_class("zh").from_config(config)
     return nlp.tokenizer


@@ -290,7 +294,10 @@ def zh_tokenizer_pkuseg():
                 "segmenter": "pkuseg",
             }
         },
-        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "default",
+        }
+        },
     }
     nlp = get_lang_class("zh").from_config(config)
     nlp.initialize()
diff --git a/spacy/tests/lang/zh/test_serialize.py b/spacy/tests/lang/zh/test_serialize.py
index 5491314e2..58c084ec8 100644
--- a/spacy/tests/lang/zh/test_serialize.py
+++ b/spacy/tests/lang/zh/test_serialize.py
@@ -28,9 +28,17 @@ def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
 @pytest.mark.slow
 def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "pkuseg",
-        "pkuseg_model": "medicine",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "medicine",
+        }
+        },
     }
-    nlp = Chinese.from_config({"nlp": {"tokenizer": config}})
+    nlp = Chinese.from_config(config)
+    nlp.initialize()
     zh_tokenizer_serialize(nlp.tokenizer)
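
Usage note: the snippet below is a minimal sketch, not part of the patch, showing the end-to-end workflow the changes above move to: the segmenter is chosen in the `nlp` config, and the pkuseg model is only loaded when `nlp.initialize()` runs (via `try_pkuseg_import`), not in `ChineseTokenizer.__init__`. It assumes the `pkuseg` package and its downloadable "default" model are available; the sample sentence and the final print are illustrative only.

from spacy.lang.zh import Chinese

config = {
    "nlp": {
        "tokenizer": {
            "@tokenizers": "spacy.zh.ChineseTokenizer",
            "segmenter": "pkuseg",
        }
    },
    "initialize": {
        "tokenizer": {
            "pkuseg_model": "default",  # or a path such as "/path/to/pkuseg_model"
        }
    },
}
nlp = Chinese.from_config(config)
nlp.initialize()  # the pkuseg model is loaded here; before this, pkuseg_seg is None
doc = nlp("作为自然语言处理的例子")
print([token.text for token in doc])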