diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py
index 752f77d11..457502e21 100644
--- a/spacy/lang/zh/__init__.py
+++ b/spacy/lang/zh/__init__.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Callable, Iterable
 from enum import Enum
 import tempfile
 import srsly
@@ -10,7 +10,7 @@ from ...errors import Warnings, Errors
 from ...language import Language
 from ...scorer import Scorer
 from ...tokens import Doc
-from ...training import validate_examples
+from ...training import validate_examples, Example
 from ...util import DummyTokenizer, registry
 from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS
@@ -28,6 +28,10 @@ DEFAULT_CONFIG = """
 [nlp.tokenizer]
 @tokenizers = "spacy.zh.ChineseTokenizer"
 segmenter = "char"
+
+[initialize]
+
+[initialize.tokenizer]
 pkuseg_model = null
 pkuseg_user_dict = "default"
 """
@@ -44,18 +48,9 @@ class Segmenter(str, Enum):
 
 
 @registry.tokenizers("spacy.zh.ChineseTokenizer")
-def create_chinese_tokenizer(
-    segmenter: Segmenter = Segmenter.char,
-    pkuseg_model: Optional[str] = None,
-    pkuseg_user_dict: Optional[str] = "default",
-):
+def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char,):
     def chinese_tokenizer_factory(nlp):
-        return ChineseTokenizer(
-            nlp,
-            segmenter=segmenter,
-            pkuseg_model=pkuseg_model,
-            pkuseg_user_dict=pkuseg_user_dict,
-        )
+        return ChineseTokenizer(nlp, segmenter=segmenter)
 
     return chinese_tokenizer_factory
 
@@ -78,6 +73,18 @@ class ChineseTokenizer(DummyTokenizer):
         self.jieba_seg = None
         self.configure_segmenter(segmenter)
 
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: Optional[str] = None
+    ):
+        self.pkuseg_model = pkuseg_model
+        self.pkuseg_user_dict = pkuseg_user_dict
+        self.configure_segmenter(self.segmenter)
+
     def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py
index 23fc5e98f..6cf019173 100644
--- a/spacy/tests/conftest.py
+++ b/spacy/tests/conftest.py
@@ -284,11 +284,16 @@ def zh_tokenizer_pkuseg():
     pytest.importorskip("pkuseg")
     pytest.importorskip("pickle5")
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "pkuseg",
-        "pkuseg_model": "default",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
     }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+    nlp = get_lang_class("zh").from_config(config)
+    nlp.initialize()
     return nlp.tokenizer
 
 
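Usage note: with this patch, pkuseg_model and pkuseg_user_dict are no longer
arguments to the tokenizer factory; they move to the [initialize.tokenizer]
config block and are applied when nlp.initialize() runs. A minimal sketch of
the new flow (assuming spaCy v3 with pkuseg installed; it mirrors the updated
conftest fixture above):

    from spacy.util import get_lang_class

    config = {
        "nlp": {
            "tokenizer": {
                "@tokenizers": "spacy.zh.ChineseTokenizer",
                "segmenter": "pkuseg",
            }
        },
        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
    }
    nlp = get_lang_class("zh").from_config(config)
    nlp.initialize()  # resolves [initialize.tokenizer] and loads the pkuseg model
    doc = nlp("这是一个中文句子")  # tokenized with pkuseg after initialization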