WIP: Test updating Chinese tokenizer

Ines Montani 2020-09-29 21:10:22 +02:00
parent 4f3102d09c
commit 6467a560e3
2 changed files with 29 additions and 17 deletions

@@ -1,4 +1,4 @@
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Callable, Iterable
 from enum import Enum
 import tempfile
 import srsly
@@ -10,7 +10,7 @@ from ...errors import Warnings, Errors
 from ...language import Language
 from ...scorer import Scorer
 from ...tokens import Doc
-from ...training import validate_examples
+from ...training import validate_examples, Example
 from ...util import DummyTokenizer, registry
 from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS
@@ -28,6 +28,10 @@ DEFAULT_CONFIG = """
 [nlp.tokenizer]
 @tokenizers = "spacy.zh.ChineseTokenizer"
 segmenter = "char"
+[initialize]
+[initialize.tokenizer]
+pkuseg_model = null
+pkuseg_user_dict = "default"
 """
@@ -44,18 +48,9 @@ class Segmenter(str, Enum):
 @registry.tokenizers("spacy.zh.ChineseTokenizer")
-def create_chinese_tokenizer(
-    segmenter: Segmenter = Segmenter.char,
-    pkuseg_model: Optional[str] = None,
-    pkuseg_user_dict: Optional[str] = "default",
-):
+def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char,):
     def chinese_tokenizer_factory(nlp):
-        return ChineseTokenizer(
-            nlp,
-            segmenter=segmenter,
-            pkuseg_model=pkuseg_model,
-            pkuseg_user_dict=pkuseg_user_dict,
-        )
+        return ChineseTokenizer(nlp, segmenter=segmenter)
     return chinese_tokenizer_factory
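
With this change the registered tokenizer factory no longer knows about pkuseg at all: it only fixes the segmenter, and any model is attached later through initialize(). A minimal sketch, assuming the default "char" segmenter, which needs no external model:

from spacy.lang.zh import Chinese

nlp = Chinese()                 # default config: segmenter = "char"
doc = nlp("我喜欢北京")          # character-level segmentation, no pkuseg required
print([t.text for t in doc])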
@@ -78,6 +73,18 @@ class ChineseTokenizer(DummyTokenizer):
         self.jieba_seg = None
         self.configure_segmenter(segmenter)
 
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language],
+        pkuseg_model: Optional[str] = None,
+        pkuseg_user_dict: Optional[str] = None
+    ):
+        self.pkuseg_model = pkuseg_model
+        self.pkuseg_user_dict = pkuseg_user_dict
+        self.configure_segmenter(self.segmenter)
+
     def configure_segmenter(self, segmenter: str):
         if segmenter not in Segmenter.values():
             warn_msg = Warnings.W103.format(
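
The new initialize() hook follows the signature of spaCy's other initializable components: an example getter, the nlp object, and the settings that used to be constructor arguments. Normally nlp.initialize() forwards the [initialize.tokenizer] block for you; calling the hook by hand would look roughly like this (the lambda is a stand-in example getter, unused by the tokenizer):

from spacy.lang.zh import Chinese

config = {"nlp": {"tokenizer": {"@tokenizers": "spacy.zh.ChineseTokenizer",
                                "segmenter": "pkuseg"}}}
nlp = Chinese.from_config(config)
nlp.tokenizer.initialize(
    lambda: [],                  # stand-in example getter
    nlp=nlp,
    pkuseg_model="default",      # or a path to a trained pkuseg model
    pkuseg_user_dict="default",
)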

@@ -284,11 +284,16 @@ def zh_tokenizer_pkuseg():
     pytest.importorskip("pkuseg")
     pytest.importorskip("pickle5")
     config = {
-        "@tokenizers": "spacy.zh.ChineseTokenizer",
-        "segmenter": "pkuseg",
-        "pkuseg_model": "default",
+        "nlp": {
+            "tokenizer": {
+                "@tokenizers": "spacy.zh.ChineseTokenizer",
+                "segmenter": "pkuseg",
+            }
+        },
+        "initialize": {"tokenizer": {"pkuseg_model": "default"}},
     }
-    nlp = get_lang_class("zh").from_config({"nlp": {"tokenizer": config}})
+    nlp = get_lang_class("zh").from_config(config)
+    nlp.initialize()
     return nlp.tokenizer
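
Tests consuming the fixture should not need to change: it still returns a ready-to-use tokenizer, now routed through nlp.initialize() so the pkuseg model is loaded before the fixture yields. A hypothetical test for illustration:

def test_zh_tokenizer_pkuseg_tokens(zh_tokenizer_pkuseg):
    # By the time this runs, initialize() has loaded the default pkuseg model.
    tokens = [t.text for t in zh_tokenizer_pkuseg("我喜欢北京烤鸭")]
    assert len(tokens) > 1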