Refactor Chinese initialization

Adriane Boyd 2020-09-30 11:46:45 +02:00
parent 34f9c26c62
commit 6b7bb32834
4 changed files with 66 additions and 66 deletions

spacy/errors.py  View File

@@ -672,14 +672,22 @@ class Errors:
     E999 = ("Unable to merge the `Doc` objects because they do not all share "
             "the same `Vocab`.")
     E1000 = ("The Chinese word segmenter is pkuseg but no pkuseg model was "
-             "specified. Provide the name of a pretrained model or the path to "
-             "a model when initializing the pipeline:\n"
+             "loaded. Provide the name of a pretrained model or the path to "
+             "a model and initialize the pipeline:\n\n"
              'config = {\n'
-             '    "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
-             '    "segmenter": "pkuseg",\n'
-             '    "pkuseg_model": "default",  # or "/path/to/pkuseg_model" \n'
+             '    "nlp": {\n'
+             '        "tokenizer": {\n'
+             '            "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
+             '            "segmenter": "pkuseg",\n'
+             '        }\n'
+             '    },\n'
+             '    "initialize": {"tokenizer": {\n'
+             '        "pkuseg_model": "default", # or /path/to/model\n'
+             '        }\n'
+             '    },\n'
              '}\n'
-             'nlp = Chinese.from_config({"nlp": {"tokenizer": config}})')
+             'nlp = Chinese.from_config(config)\n'
+             'nlp.initialize()')
     E1001 = ("Target token outside of matched span for match with tokens "
             "'{span}' and offset '{index}' matched by patterns '{patterns}'.")
     E1002 = ("Span index out of range.")
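
For reference, the recipe that the updated E1000 message spells out looks roughly like this as runnable code (a sketch, assuming spaCy v3 with the Chinese language data and a pkuseg package installed; not part of the diff):

from spacy.lang.zh import Chinese

# The tokenizer config now only declares the segmenter; the pkuseg model is
# supplied separately under "initialize" and loaded by nlp.initialize().
config = {
    "nlp": {
        "tokenizer": {
            "@tokenizers": "spacy.zh.ChineseTokenizer",
            "segmenter": "pkuseg",
        }
    },
    "initialize": {
        "tokenizer": {
            "pkuseg_model": "default",  # or "/path/to/model"
        }
    },
}
nlp = Chinese.from_config(config)
nlp.initialize()  # this step loads the pkuseg model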

spacy/lang/zh/__init__.py  View File

@@ -59,32 +59,13 @@ class ChineseTokenizer(DummyTokenizer):
         self,
         nlp: Language,
         segmenter: Segmenter = Segmenter.char,
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None,
     ):
         self.vocab = nlp.vocab
         if isinstance(segmenter, Segmenter):
             segmenter = segmenter.value
         self.segmenter = segmenter
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
         self.pkuseg_seg = None
         self.jieba_seg = None
-        self.configure_segmenter(segmenter)
-
-    def initialize(
-        self,
-        get_examples: Callable[[], Iterable[Example]],
-        *,
-        nlp: Optional[Language],
-        pkuseg_model: Optional[str] = None,
-        pkuseg_user_dict: Optional[str] = None
-    ):
-        self.pkuseg_model = pkuseg_model
-        self.pkuseg_user_dict = pkuseg_user_dict
-        self.configure_segmenter(self.segmenter)
-
-    def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
                 lang="Chinese",
@@ -94,11 +75,20 @@ class ChineseTokenizer(DummyTokenizer):
             )
             warnings.warn(warn_msg)
             self.segmenter = Segmenter.char
-        self.jieba_seg = try_jieba_import(self.segmenter)
-        self.pkuseg_seg = try_pkuseg_import(
-            self.segmenter,
-            pkuseg_model=self.pkuseg_model,
-            pkuseg_user_dict=self.pkuseg_user_dict,
-        )
+        if segmenter == Segmenter.jieba:
+            self.jieba_seg = try_jieba_import()
+
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: str = "default",
+    ):
+        if self.segmenter == Segmenter.pkuseg:
+            self.pkuseg_seg = try_pkuseg_import(
+                pkuseg_model=pkuseg_model, pkuseg_user_dict=pkuseg_user_dict,
+            )

     def __call__(self, text: str) -> Doc:
@@ -154,14 +144,10 @@ class ChineseTokenizer(DummyTokenizer):
     def _get_config(self) -> Dict[str, Any]:
         return {
             "segmenter": self.segmenter,
-            "pkuseg_model": self.pkuseg_model,
-            "pkuseg_user_dict": self.pkuseg_user_dict,
         }

     def _set_config(self, config: Dict[str, Any] = {}) -> None:
         self.segmenter = config.get("segmenter", Segmenter.char)
-        self.pkuseg_model = config.get("pkuseg_model", None)
-        self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")

     def to_bytes(self, **kwargs):
         pkuseg_features_b = b""
@@ -339,17 +325,15 @@ class Chinese(Language):
     Defaults = ChineseDefaults


-def try_jieba_import(segmenter: str) -> None:
+def try_jieba_import() -> None:
     try:
         import jieba

-        if segmenter == Segmenter.jieba:
-            # segment a short text to have jieba initialize its cache in advance
-            list(jieba.cut("作为", cut_all=False))
+        # segment a short text to have jieba initialize its cache in advance
+        list(jieba.cut("作为", cut_all=False))

         return jieba
     except ImportError:
-        if segmenter == Segmenter.jieba:
-            msg = (
-                "Jieba not installed. To use jieba, install it with `pip "
-                " install jieba` or from https://github.com/fxsjy/jieba"
+        msg = (
+            "Jieba not installed. To use jieba, install it with `pip "
+            " install jieba` or from https://github.com/fxsjy/jieba"
@@ -357,22 +341,15 @@ def try_jieba_import(segmenter: str) -> None:
             raise ImportError(msg) from None


-def try_pkuseg_import(
-    segmenter: str, pkuseg_model: Optional[str], pkuseg_user_dict: str
-) -> None:
+def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
     try:
         import pkuseg

-        if pkuseg_model is None:
-            return None
-        else:
-            return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
+        return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
     except ImportError:
-        if segmenter == Segmenter.pkuseg:
-            msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
-            raise ImportError(msg) from None
+        msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
+        raise ImportError(msg) from None
     except FileNotFoundError:
-        if segmenter == Segmenter.pkuseg:
-            msg = "Unable to load pkuseg model from: " + pkuseg_model
-            raise FileNotFoundError(msg) from None
+        msg = "Unable to load pkuseg model from: " + pkuseg_model
+        raise FileNotFoundError(msg) from None
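
Taken together, the tokenizer changes mean that constructing a Chinese pipeline never loads a segmentation model: __init__ only imports jieba when the segmenter is "jieba", and the pkuseg model is loaded later by initialize() from the [initialize.tokenizer] settings. A rough sketch of the two paths (hypothetical usage; assumes the jieba and pkuseg packages are installed):

from spacy.lang.zh import Chinese

# jieba: ready right after construction, nothing to initialize
nlp_jieba = Chinese.from_config({
    "nlp": {"tokenizer": {"@tokenizers": "spacy.zh.ChineseTokenizer",
                          "segmenter": "jieba"}},
})
print([t.text for t in nlp_jieba("我住在北京")])

# pkuseg: the model is only loaded by initialize(); calling the tokenizer
# before that point should fail with the E1000 error shown above
nlp_pkuseg = Chinese.from_config({
    "nlp": {"tokenizer": {"@tokenizers": "spacy.zh.ChineseTokenizer",
                          "segmenter": "pkuseg"}},
    "initialize": {"tokenizer": {"pkuseg_model": "default"}},
})
nlp_pkuseg.initialize()
print([t.text for t in nlp_pkuseg("我住在北京")])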

spacy/tests/conftest.py  View File

@@ -272,10 +272,14 @@ def zh_tokenizer_char():
 def zh_tokenizer_jieba():
     pytest.importorskip("jieba")
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "jieba",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "jieba",
+            }
+        }
     }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+    nlp = get_lang_class("zh").from_config(config)
     return nlp.tokenizer
@@ -290,7 +294,10 @@ def zh_tokenizer_pkuseg():
                 "segmenter": "pkuseg",
             }
         },
-        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "default",
+        }
+        },
     }
     nlp = get_lang_class("zh").from_config(config)
     nlp.initialize()

spacy/tests/lang/zh/test_serialize.py  View File

@@ -28,9 +28,17 @@ def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
 @pytest.mark.slow
 def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "pkuseg",
-        "pkuseg_model": "medicine",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {
+            "pkuseg_model": "medicine",
+        }
+        },
     }
-    nlp = Chinese.from_config({"nlp": {"tokenizer": config}})
+    nlp = Chinese.from_config(config)
+    nlp.initialize()
     zh_tokenizer_serialize(nlp.tokenizer)
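
The zh_tokenizer_serialize helper used by these tests is not part of this diff; presumably it round-trips the tokenizer through to_bytes()/from_bytes(). A minimal sketch of that kind of check (hypothetical implementation, the real helper may differ):

from spacy.lang.zh import Chinese

def zh_tokenizer_serialize(zh_tokenizer):
    # Serialize the tokenizer, load the bytes into a fresh pipeline, and
    # check that the round-tripped tokenizer serializes to the same bytes.
    tokenizer_bytes = zh_tokenizer.to_bytes()
    nlp = Chinese()
    nlp.tokenizer.from_bytes(tokenizer_bytes)
    assert tokenizer_bytes == nlp.tokenizer.to_bytes()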