Add pkuseg and serialization support for Chinese (#5308)

* Add pkuseg and serialization support for Chinese

Add support for pkuseg alongside jieba

* Specify model through `Language` meta:

  * split on characters (if no word segmentation packages are installed)

```
Chinese(meta={"tokenizer": {"config": {"use_jieba": False, "use_pkuseg": False}}})
```

  * jieba (remains the default tokenizer if installed)

```
Chinese()
Chinese(meta={"tokenizer": {"config": {"use_jieba": True}}}) # explicit
```

  * pkuseg

```
Chinese(meta={"tokenizer": {"config": {"pkuseg_model": "default", "use_jieba": False, "use_pkuseg": True}}})
```

* The new tokenizer setting `require_pkuseg` is used to override
`use_jieba` default, which is intended for models that provide a pkuseg
model:

```
nlp_pkuseg = Chinese(meta={"tokenizer": {"config": {"pkuseg_model": "default", "require_pkuseg": True}}})
nlp = Chinese() # has `use_jieba` as `True` by default
nlp.from_bytes(nlp_pkuseg.to_bytes()) # `require_pkuseg` overrides `use_jieba` when calling the tokenizer
```

Add support for serialization of tokenizer settings and pkuseg model, if
loaded

* Add sorting for `Language.to_bytes()` serialization of `Language.meta`
so that the (emptied, but still present) tokenizer metadata is in a
consistent position in the serialized data

Extend tests to cover all three tokenizer configurations and
serialization

* Fix from_disk and tests without jieba or pkuseg

* Load cfg first and only show error if `use_pkuseg`
* Fix blank/default initialization in serialization tests

* Explicitly initialize jieba's cache on init

* Add serialization for pkuseg pre/postprocessors

* Reformat pkuseg install message
This commit is contained in:
adrianeboyd 2020-04-18 17:01:53 +02:00 committed by GitHub
parent fb73d4943a
commit f7471abd82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 329 additions and 62 deletions

View File

@ -1,6 +1,10 @@
# coding: utf8
from __future__ import unicode_literals
import tempfile
import srsly
from pathlib import Path
from collections import OrderedDict
from ...attrs import LANG
from ...language import Language
from ...tokens import Doc
@ -9,12 +13,19 @@ from ..tokenizer_exceptions import BASE_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from .stop_words import STOP_WORDS
from .tag_map import TAG_MAP
from ... import util
_PKUSEG_INSTALL_MSG = "install it with `pip install pkuseg==0.0.22` or from https://github.com/lancopku/pkuseg-python"
def try_jieba_import(use_jieba):
try:
import jieba
# segment a short text to have jieba initialize its cache in advance
list(jieba.cut("作为", cut_all=False))
return jieba
except ImportError:
if use_jieba:
@ -25,59 +36,241 @@ def try_jieba_import(use_jieba):
raise ImportError(msg)
def try_pkuseg_import(use_pkuseg, pkuseg_model, pkuseg_user_dict):
try:
import pkuseg
if pkuseg_model:
return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
elif use_pkuseg:
msg = (
"Chinese.use_pkuseg is True but no pkuseg model was specified. "
"Please provide the name of a pretrained model "
"or the path to a model with "
'`Chinese(meta={"tokenizer": {"config": {"pkuseg_model": name_or_path}}}).'
)
raise ValueError(msg)
except ImportError:
if use_pkuseg:
msg = (
"pkuseg not installed. Either set Chinese.use_pkuseg = False, "
"or " + _PKUSEG_INSTALL_MSG
)
raise ImportError(msg)
except FileNotFoundError:
if use_pkuseg:
msg = "Unable to load pkuseg model from: " + pkuseg_model
raise FileNotFoundError(msg)
class ChineseTokenizer(DummyTokenizer):
def __init__(self, cls, nlp=None):
def __init__(self, cls, nlp=None, config={}):
self.use_jieba = config.get("use_jieba", cls.use_jieba)
self.use_pkuseg = config.get("use_pkuseg", cls.use_pkuseg)
self.require_pkuseg = config.get("require_pkuseg", False)
self.vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
self.use_jieba = cls.use_jieba
self.jieba_seg = try_jieba_import(self.use_jieba)
self.pkuseg_seg = try_pkuseg_import(
self.use_pkuseg,
pkuseg_model=config.get("pkuseg_model", None),
pkuseg_user_dict=config.get("pkuseg_user_dict", "default"),
)
# remove relevant settings from config so they're not also saved in
# Language.meta
for key in ["use_jieba", "use_pkuseg", "require_pkuseg", "pkuseg_model"]:
if key in config:
del config[key]
self.tokenizer = Language.Defaults().create_tokenizer(nlp)
def __call__(self, text):
# use jieba
if self.use_jieba:
jieba_words = list(
[x for x in self.jieba_seg.cut(text, cut_all=False) if x]
)
words = [jieba_words[0]]
spaces = [False]
for i in range(1, len(jieba_words)):
word = jieba_words[i]
if word.isspace():
# second token in adjacent whitespace following a
# non-space token
if spaces[-1]:
words.append(word)
spaces.append(False)
# first space token following non-space token
elif word == " " and not words[-1].isspace():
spaces[-1] = True
# token is non-space whitespace or any whitespace following
# a whitespace token
use_jieba = self.use_jieba
use_pkuseg = self.use_pkuseg
if self.require_pkuseg:
use_jieba = False
use_pkuseg = True
if use_jieba:
words = list([x for x in self.jieba_seg.cut(text, cut_all=False) if x])
(words, spaces) = util.get_words_and_spaces(words, text)
return Doc(self.vocab, words=words, spaces=spaces)
elif use_pkuseg:
words = self.pkuseg_seg.cut(text)
(words, spaces) = util.get_words_and_spaces(words, text)
return Doc(self.vocab, words=words, spaces=spaces)
else:
# extend previous whitespace token with more whitespace
if words[-1].isspace():
words[-1] += word
# otherwise it's a new whitespace token
else:
words.append(word)
spaces.append(False)
else:
words.append(word)
spaces.append(False)
# split into individual characters
words = list(text)
(words, spaces) = util.get_words_and_spaces(words, text)
return Doc(self.vocab, words=words, spaces=spaces)
# split into individual characters
words = []
spaces = []
for token in self.tokenizer(text):
if token.text.isspace():
words.append(token.text)
spaces.append(False)
else:
words.extend(list(token.text))
spaces.extend([False] * len(token.text))
spaces[-1] = bool(token.whitespace_)
return Doc(self.vocab, words=words, spaces=spaces)
def _get_config(self):
config = OrderedDict(
(
("use_jieba", self.use_jieba),
("use_pkuseg", self.use_pkuseg),
("require_pkuseg", self.require_pkuseg),
)
)
return config
def _set_config(self, config={}):
self.use_jieba = config.get("use_jieba", False)
self.use_pkuseg = config.get("use_pkuseg", False)
self.require_pkuseg = config.get("require_pkuseg", False)
def to_bytes(self, **kwargs):
pkuseg_features_b = b""
pkuseg_weights_b = b""
pkuseg_processors_data = None
if self.pkuseg_seg:
with tempfile.TemporaryDirectory() as tempdir:
self.pkuseg_seg.feature_extractor.save(tempdir)
self.pkuseg_seg.model.save(tempdir)
tempdir = Path(tempdir)
with open(tempdir / "features.pkl", "rb") as fileh:
pkuseg_features_b = fileh.read()
with open(tempdir / "weights.npz", "rb") as fileh:
pkuseg_weights_b = fileh.read()
pkuseg_processors_data = (
_get_pkuseg_trie_data(self.pkuseg_seg.preprocesser.trie),
self.pkuseg_seg.postprocesser.do_process,
sorted(list(self.pkuseg_seg.postprocesser.common_words)),
sorted(list(self.pkuseg_seg.postprocesser.other_words)),
)
serializers = OrderedDict(
(
("cfg", lambda: srsly.json_dumps(self._get_config())),
("pkuseg_features", lambda: pkuseg_features_b),
("pkuseg_weights", lambda: pkuseg_weights_b),
(
"pkuseg_processors",
lambda: srsly.msgpack_dumps(pkuseg_processors_data),
),
)
)
return util.to_bytes(serializers, [])
def from_bytes(self, data, **kwargs):
pkuseg_features_b = b""
pkuseg_weights_b = b""
pkuseg_processors_data = None
def deserialize_pkuseg_features(b):
nonlocal pkuseg_features_b
pkuseg_features_b = b
def deserialize_pkuseg_weights(b):
nonlocal pkuseg_weights_b
pkuseg_weights_b = b
def deserialize_pkuseg_processors(b):
nonlocal pkuseg_processors_data
pkuseg_processors_data = srsly.msgpack_loads(b)
deserializers = OrderedDict(
(
("cfg", lambda b: self._set_config(srsly.json_loads(b))),
("pkuseg_features", deserialize_pkuseg_features),
("pkuseg_weights", deserialize_pkuseg_weights),
("pkuseg_processors", deserialize_pkuseg_processors),
)
)
util.from_bytes(data, deserializers, [])
if pkuseg_features_b and pkuseg_weights_b:
with tempfile.TemporaryDirectory() as tempdir:
tempdir = Path(tempdir)
with open(tempdir / "features.pkl", "wb") as fileh:
fileh.write(pkuseg_features_b)
with open(tempdir / "weights.npz", "wb") as fileh:
fileh.write(pkuseg_weights_b)
try:
import pkuseg
except ImportError:
raise ImportError(
"pkuseg not installed. To use this model, "
+ _PKUSEG_INSTALL_MSG
)
self.pkuseg_seg = pkuseg.pkuseg(str(tempdir))
if pkuseg_processors_data:
(
user_dict,
do_process,
common_words,
other_words,
) = pkuseg_processors_data
self.pkuseg_seg.preprocesser = pkuseg.Preprocesser(user_dict)
self.pkuseg_seg.postprocesser.do_process = do_process
self.pkuseg_seg.postprocesser.common_words = set(common_words)
self.pkuseg_seg.postprocesser.other_words = set(other_words)
return self
def to_disk(self, path, **kwargs):
path = util.ensure_path(path)
def save_pkuseg_model(path):
if self.pkuseg_seg:
if not path.exists():
path.mkdir(parents=True)
self.pkuseg_seg.model.save(path)
self.pkuseg_seg.feature_extractor.save(path)
def save_pkuseg_processors(path):
if self.pkuseg_seg:
data = (
_get_pkuseg_trie_data(self.pkuseg_seg.preprocesser.trie),
self.pkuseg_seg.postprocesser.do_process,
sorted(list(self.pkuseg_seg.postprocesser.common_words)),
sorted(list(self.pkuseg_seg.postprocesser.other_words)),
)
srsly.write_msgpack(path, data)
serializers = OrderedDict(
(
("cfg", lambda p: srsly.write_json(p, self._get_config())),
("pkuseg_model", lambda p: save_pkuseg_model(p)),
("pkuseg_processors", lambda p: save_pkuseg_processors(p)),
)
)
return util.to_disk(path, serializers, [])
def from_disk(self, path, **kwargs):
path = util.ensure_path(path)
def load_pkuseg_model(path):
try:
import pkuseg
except ImportError:
if self.use_pkuseg:
raise ImportError(
"pkuseg not installed. To use this model, "
+ _PKUSEG_INSTALL_MSG
)
if path.exists():
self.pkuseg_seg = pkuseg.pkuseg(path)
def load_pkuseg_processors(path):
try:
import pkuseg
except ImportError:
if self.use_pkuseg:
raise ImportError(self._pkuseg_install_msg)
if self.pkuseg_seg:
data = srsly.read_msgpack(path)
(user_dict, do_process, common_words, other_words) = data
self.pkuseg_seg.preprocesser = pkuseg.Preprocesser(user_dict)
self.pkuseg_seg.postprocesser.do_process = do_process
self.pkuseg_seg.postprocesser.common_words = set(common_words)
self.pkuseg_seg.postprocesser.other_words = set(other_words)
serializers = OrderedDict(
(
("cfg", lambda p: self._set_config(srsly.read_json(p))),
("pkuseg_model", lambda p: load_pkuseg_model(p)),
("pkuseg_processors", lambda p: load_pkuseg_processors(p)),
)
)
util.from_disk(path, serializers, [])
class ChineseDefaults(Language.Defaults):
@ -89,10 +282,11 @@ class ChineseDefaults(Language.Defaults):
tag_map = TAG_MAP
writing_system = {"direction": "ltr", "has_case": False, "has_letters": False}
use_jieba = True
use_pkuseg = False
@classmethod
def create_tokenizer(cls, nlp=None):
return ChineseTokenizer(cls, nlp)
def create_tokenizer(cls, nlp=None, config={}):
return ChineseTokenizer(cls, nlp, config=config)
class Chinese(Language):
@ -103,4 +297,13 @@ class Chinese(Language):
return self.tokenizer(text)
def _get_pkuseg_trie_data(node, path=""):
data = []
for c, child_node in sorted(node.children.items()):
data.extend(_get_pkuseg_trie_data(child_node, path + c))
if node.isword:
data.append((path, node.usertag))
return data
__all__ = ["Chinese"]

View File

@ -969,7 +969,7 @@ class Language(object):
serializers = OrderedDict()
serializers["vocab"] = lambda: self.vocab.to_bytes()
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["meta.json"] = lambda: srsly.json_dumps(OrderedDict(sorted(self.meta.items())))
for name, proc in self.pipeline:
if name in exclude:
continue

View File

@ -231,10 +231,22 @@ def yo_tokenizer():
@pytest.fixture(scope="session")
def zh_tokenizer():
def zh_tokenizer_char():
return get_lang_class("zh").Defaults.create_tokenizer(config={"use_jieba": False, "use_pkuseg": False})
@pytest.fixture(scope="session")
def zh_tokenizer_jieba():
pytest.importorskip("jieba")
return get_lang_class("zh").Defaults.create_tokenizer()
@pytest.fixture(scope="session")
def zh_tokenizer_pkuseg():
pytest.importorskip("pkuseg")
return get_lang_class("zh").Defaults.create_tokenizer(config={"pkuseg_model": "default", "use_jieba": False, "use_pkuseg": True})
@pytest.fixture(scope="session")
def hy_tokenizer():
return get_lang_class("hy").Defaults.create_tokenizer()

View File

@ -0,0 +1,38 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
from spacy.lang.zh import Chinese
from ...util import make_tempdir
def zh_tokenizer_serialize(zh_tokenizer):
tokenizer_bytes = zh_tokenizer.to_bytes()
nlp = Chinese(meta={"tokenizer": {"config": {"use_jieba": False}}})
nlp.tokenizer.from_bytes(tokenizer_bytes)
assert tokenizer_bytes == nlp.tokenizer.to_bytes()
with make_tempdir() as d:
file_path = d / "tokenizer"
zh_tokenizer.to_disk(file_path)
nlp = Chinese(meta={"tokenizer": {"config": {"use_jieba": False}}})
nlp.tokenizer.from_disk(file_path)
assert tokenizer_bytes == nlp.tokenizer.to_bytes()
def test_zh_tokenizer_serialize_char(zh_tokenizer_char):
zh_tokenizer_serialize(zh_tokenizer_char)
def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
zh_tokenizer_serialize(zh_tokenizer_jieba)
def test_zh_tokenizer_serialize_pkuseg(zh_tokenizer_pkuseg):
zh_tokenizer_serialize(zh_tokenizer_pkuseg)
@pytest.mark.slow
def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
nlp = Chinese(meta={"tokenizer": {"config": {"use_jieba": False, "use_pkuseg": True, "pkuseg_model": "medicine"}}})
zh_tokenizer_serialize(nlp.tokenizer)

View File

@ -19,7 +19,7 @@ import pytest
(",", False),
],
)
def test_lex_attrs_like_number(zh_tokenizer, text, match):
tokens = zh_tokenizer(text)
def test_lex_attrs_like_number(zh_tokenizer_jieba, text, match):
tokens = zh_tokenizer_jieba(text)
assert len(tokens) == 1
assert tokens[0].like_num == match

View File

@ -5,27 +5,41 @@ import pytest
# fmt: off
TOKENIZER_TESTS = [
("作为语言而言,为世界使用人数最多的语言,目前世界有五分之一人口做为母语。",
TEXTS = ("作为语言而言,为世界使用人数最多的语言,目前世界有五分之一人口做为母语。",)
JIEBA_TOKENIZER_TESTS = [
(TEXTS[0],
['作为', '语言', '而言', '', '', '世界', '使用', '', '数最多',
'', '语言', '', '目前', '世界', '', '五分之一', '人口', '',
'', '母语', '']),
]
PKUSEG_TOKENIZER_TESTS = [
(TEXTS[0],
['作为', '语言', '而言', '', '', '世界', '使用', '人数', '最多',
'', '语言', '', '目前', '世界', '', '五分之一', '人口', '做为',
'母语', '']),
]
# fmt: on
@pytest.mark.parametrize("text,expected_tokens", TOKENIZER_TESTS)
def test_zh_tokenizer(zh_tokenizer, text, expected_tokens):
zh_tokenizer.use_jieba = False
tokens = [token.text for token in zh_tokenizer(text)]
@pytest.mark.parametrize("text", TEXTS)
def test_zh_tokenizer_char(zh_tokenizer_char, text):
tokens = [token.text for token in zh_tokenizer_char(text)]
assert tokens == list(text)
zh_tokenizer.use_jieba = True
tokens = [token.text for token in zh_tokenizer(text)]
@pytest.mark.parametrize("text,expected_tokens", JIEBA_TOKENIZER_TESTS)
def test_zh_tokenizer_jieba(zh_tokenizer_jieba, text, expected_tokens):
tokens = [token.text for token in zh_tokenizer_jieba(text)]
assert tokens == expected_tokens
def test_extra_spaces(zh_tokenizer):
@pytest.mark.parametrize("text,expected_tokens", PKUSEG_TOKENIZER_TESTS)
def test_zh_tokenizer_pkuseg(zh_tokenizer_pkuseg, text, expected_tokens):
tokens = [token.text for token in zh_tokenizer_pkuseg(text)]
assert tokens == expected_tokens
def test_extra_spaces(zh_tokenizer_char):
# note: three spaces after "I"
tokens = zh_tokenizer("I like cheese.")
tokens = zh_tokenizer_char("I like cheese.")
assert tokens[1].orth_ == " "