spaCy/spacy/lang/zh/__init__.py
Daniël de Kok e2b70df012
Configure isort to use the Black profile, recursively isort the spacy module (#12721)
* Use isort with Black profile

* isort all the things

* Fix import cycles as a result of import sorting

* Add DOCBIN_ALL_ATTRS type definition

* Add isort to requirements

* Remove isort from build dependencies check

* Typo
2023-06-14 17:48:41 +02:00

338 lines
12 KiB
Python

import tempfile
import warnings
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional
import srsly
from ... import util
from ...errors import Errors, Warnings
from ...language import BaseDefaults, Language
from ...scorer import Scorer
from ...tokens import Doc
from ...training import Example, validate_examples
from ...util import DummyTokenizer, load_config_from_str, registry
from ...vocab import Vocab
from .lex_attrs import LEX_ATTRS
from .stop_words import STOP_WORDS
# fmt: off
# Installation hint that is appended to the "spacy-pkuseg not installed"
# error messages raised in this module.
_PKUSEG_INSTALL_MSG = "install spacy-pkuseg with `pip install \"spacy-pkuseg>=0.0.27,<0.1.0\"` or `conda install -c conda-forge \"spacy-pkuseg>=0.0.27,<0.1.0\"`"
# fmt: on
# Default config for Chinese: the tokenizer is created through the
# "spacy.zh.ChineseTokenizer" registry entry with character segmentation, and
# the tokenizer's initialize settings default to no pkuseg model and the
# "default" pkuseg user dictionary. It is parsed with load_config_from_str
# when the language defaults are defined.
DEFAULT_CONFIG = """
[nlp]
[nlp.tokenizer]
@tokenizers = "spacy.zh.ChineseTokenizer"
segmenter = "char"
[initialize]
[initialize.tokenizer]
pkuseg_model = null
pkuseg_user_dict = "default"
"""
class Segmenter(str, Enum):
    """Word segmentation backends supported by the Chinese tokenizer.

    The enum mixes in ``str`` so that members compare equal to their plain
    string values (e.g. ``Segmenter.char == "char"``).
    """

    char = "char"
    jieba = "jieba"
    pkuseg = "pkuseg"

    @classmethod
    def values(cls):
        """Return the string values of all members, in definition order.

        The result is built from the member *values* rather than the member
        names, so it stays correct even if a name and its value ever differ.

        RETURNS (List[str]): The supported segmenter settings.
        """
        return [member.value for member in cls]
@registry.tokenizers("spacy.zh.ChineseTokenizer")
def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char):
    """Build the tokenizer factory registered as "spacy.zh.ChineseTokenizer".

    segmenter (Segmenter): Segmentation backend handed to the tokenizer.
    RETURNS (Callable): A function that takes the nlp object and returns a
        ChineseTokenizer sharing that object's vocab.
    """

    def chinese_tokenizer_factory(nlp):
        # The tokenizer is bound to the vocab of the pipeline it is built for.
        vocab = nlp.vocab
        return ChineseTokenizer(vocab, segmenter=segmenter)

    return chinese_tokenizer_factory
class ChineseTokenizer(DummyTokenizer):
    """Tokenizer for Chinese text with a selectable segmentation backend.

    Depending on ``segmenter`` the text is split into single characters
    ("char"), segmented with jieba ("jieba") or segmented with spacy-pkuseg
    ("pkuseg"). The pkuseg segmenter is not created in ``__init__``; it is
    set by ``initialize`` or restored by ``from_bytes`` / ``from_disk``.
    """

    def __init__(self, vocab: Vocab, segmenter: Segmenter = Segmenter.char):
        """Create the tokenizer.

        vocab (Vocab): The vocab used for the Doc objects that are created.
        segmenter (Segmenter): A Segmenter member or its string value. An
            unsupported value triggers warning W103 and falls back to "char".
        """
        self.vocab = vocab
        # Store the plain string value so the setting serializes as a string.
        self.segmenter = (
            segmenter.value if isinstance(segmenter, Segmenter) else segmenter
        )
        self.pkuseg_seg = None
        self.jieba_seg = None
        if self.segmenter not in Segmenter.values():
            warn_msg = Warnings.W103.format(
                lang="Chinese",
                segmenter=self.segmenter,
                supported=", ".join(Segmenter.values()),
                default="'char' (character segmentation)",
            )
            warnings.warn(warn_msg)
            self.segmenter = Segmenter.char
        if self.segmenter == Segmenter.jieba:
            # jieba is imported eagerly so a missing install fails right away
            self.jieba_seg = try_jieba_import()

    def initialize(
        self,
        get_examples: Optional[Callable[[], Iterable[Example]]] = None,
        *,
        nlp: Optional[Language] = None,
        pkuseg_model: Optional[str] = None,
        pkuseg_user_dict: Optional[str] = "default",
    ):
        """Load the pkuseg segmenter if pkuseg is the configured segmenter.

        get_examples / nlp: Accepted but not used by this method.
        pkuseg_model (Optional[str]): Model passed on to spacy-pkuseg.
        pkuseg_user_dict (Optional[str]): User dictionary passed on to
            spacy-pkuseg. If None, the value of pkuseg_model is passed as the
            user dictionary instead.
        """
        if self.segmenter == Segmenter.pkuseg:
            if pkuseg_user_dict is None:
                pkuseg_user_dict = pkuseg_model
            self.pkuseg_seg = try_pkuseg_import(
                pkuseg_model=pkuseg_model, pkuseg_user_dict=pkuseg_user_dict
            )

    def __call__(self, text: str) -> Doc:
        """Segment a text with the configured backend.

        text (str): The text to tokenize.
        RETURNS (Doc): The Doc built from the segmented words.
        RAISES (ValueError): Errors.E1000 if pkuseg is the configured
            segmenter but no pkuseg segmenter has been loaded yet.
        """
        if self.segmenter == Segmenter.jieba:
            # drop empty strings produced by the segmenter
            words = [x for x in self.jieba_seg.cut(text, cut_all=False) if x]  # type: ignore[union-attr]
        elif self.segmenter == Segmenter.pkuseg:
            if self.pkuseg_seg is None:
                raise ValueError(Errors.E1000)
            words = self.pkuseg_seg.cut(text)
        else:
            # warn if segmenter setting is not the only remaining option "char"
            if self.segmenter != Segmenter.char:
                warn_msg = Warnings.W103.format(
                    lang="Chinese",
                    segmenter=self.segmenter,
                    supported=", ".join(Segmenter.values()),
                    default="'char' (character segmentation)",
                )
                warnings.warn(warn_msg)
            # split into individual characters
            words = list(text)
        (words, spaces) = util.get_words_and_spaces(words, text)
        return Doc(self.vocab, words=words, spaces=spaces)

    def pkuseg_update_user_dict(self, words: List[str], reset: bool = False):
        """Add words to the pkuseg user dictionary.

        Warns with W104 and does nothing if pkuseg is not the configured
        segmenter.

        words (List[str]): Words to insert; each is stripped of surrounding
            whitespace first.
        reset (bool): Replace the current preprocesser with an empty one
            before inserting the words.
        RAISES (ImportError): If reset is requested and spacy_pkuseg cannot
            be imported.
        """
        if self.segmenter != Segmenter.pkuseg:
            warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
            warnings.warn(warn_msg)
            return
        if reset:
            try:
                import spacy_pkuseg

                self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(None)  # type: ignore[attr-defined]
            except ImportError:
                msg = (
                    "spacy_pkuseg not installed: unable to reset pkuseg "
                    "user dict. Please " + _PKUSEG_INSTALL_MSG
                )
                raise ImportError(msg) from None
        for word in words:
            self.pkuseg_seg.preprocesser.insert(word.strip(), "")  # type: ignore[attr-defined]

    def score(self, examples):
        """Validate the examples and return the tokenization scores."""
        validate_examples(examples, "ChineseTokenizer.score")
        return Scorer.score_tokenization(examples)

    def _get_config(self) -> Dict[str, Any]:
        """Return the settings that are serialized under the "cfg" key."""
        return {
            "segmenter": self.segmenter,
        }

    def _set_config(self, config: Optional[Dict[str, Any]] = None) -> None:
        """Restore the settings from a deserialized "cfg" dict.

        config (Optional[Dict[str, Any]]): The loaded settings. A missing
            "segmenter" entry (or no config at all) selects Segmenter.char.
        """
        if config is None:
            config = {}
        self.segmenter = config.get("segmenter", Segmenter.char)

    def to_bytes(self, **kwargs):
        """Serialize the settings and, if loaded, the pkuseg model to bytes.

        The pkuseg feature extractor and model are saved to a temporary
        directory and read back, because they are written as files.

        RETURNS (bytes): The serialized tokenizer.
        """
        pkuseg_features_b = b""
        pkuseg_weights_b = b""
        pkuseg_processors_data = None
        if self.pkuseg_seg:
            with tempfile.TemporaryDirectory() as tempdir:
                self.pkuseg_seg.feature_extractor.save(tempdir)
                self.pkuseg_seg.model.save(tempdir)
                tempdir = Path(tempdir)
                with open(tempdir / "features.msgpack", "rb") as fileh:
                    pkuseg_features_b = fileh.read()
                with open(tempdir / "weights.npz", "rb") as fileh:
                    pkuseg_weights_b = fileh.read()
            # word sets are sorted so the serialized output is deterministic
            pkuseg_processors_data = (
                _get_pkuseg_trie_data(self.pkuseg_seg.preprocesser.trie),
                self.pkuseg_seg.postprocesser.do_process,
                sorted(self.pkuseg_seg.postprocesser.common_words),
                sorted(self.pkuseg_seg.postprocesser.other_words),
            )
        serializers = {
            "cfg": lambda: srsly.json_dumps(self._get_config()),
            "pkuseg_features": lambda: pkuseg_features_b,
            "pkuseg_weights": lambda: pkuseg_weights_b,
            "pkuseg_processors": lambda: srsly.msgpack_dumps(pkuseg_processors_data),
        }
        return util.to_bytes(serializers, [])

    def from_bytes(self, data, **kwargs):
        """Restore the tokenizer from the output of ``to_bytes``.

        data (bytes): The serialized tokenizer.
        RETURNS (ChineseTokenizer): The tokenizer itself.
        RAISES (ImportError): If the data contains a pkuseg model but
            spacy_pkuseg cannot be imported.
        """
        pkuseg_data = {"features_b": b"", "weights_b": b"", "processors_data": None}

        def deserialize_pkuseg_features(b):
            pkuseg_data["features_b"] = b

        def deserialize_pkuseg_weights(b):
            pkuseg_data["weights_b"] = b

        def deserialize_pkuseg_processors(b):
            pkuseg_data["processors_data"] = srsly.msgpack_loads(b)

        deserializers = {
            "cfg": lambda b: self._set_config(srsly.json_loads(b)),
            "pkuseg_features": deserialize_pkuseg_features,
            "pkuseg_weights": deserialize_pkuseg_weights,
            "pkuseg_processors": deserialize_pkuseg_processors,
        }
        util.from_bytes(data, deserializers, [])

        # Only rebuild the pkuseg segmenter if a model was actually serialized.
        if pkuseg_data["features_b"] and pkuseg_data["weights_b"]:
            with tempfile.TemporaryDirectory() as tempdir:
                tempdir = Path(tempdir)
                with open(tempdir / "features.msgpack", "wb") as fileh:
                    fileh.write(pkuseg_data["features_b"])
                with open(tempdir / "weights.npz", "wb") as fileh:
                    fileh.write(pkuseg_data["weights_b"])
                try:
                    import spacy_pkuseg
                except ImportError:
                    raise ImportError(
                        "spacy-pkuseg not installed. To use this model, "
                        + _PKUSEG_INSTALL_MSG
                    ) from None
                self.pkuseg_seg = spacy_pkuseg.pkuseg(str(tempdir))
            if pkuseg_data["processors_data"]:
                processors_data = pkuseg_data["processors_data"]
                (user_dict, do_process, common_words, other_words) = processors_data
                self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(user_dict)
                self.pkuseg_seg.postprocesser.do_process = do_process
                self.pkuseg_seg.postprocesser.common_words = set(common_words)
                self.pkuseg_seg.postprocesser.other_words = set(other_words)
        return self

    def to_disk(self, path, **kwargs):
        """Save the settings and, if loaded, the pkuseg model to a directory.

        path (str / Path): The directory to write to.
        RETURNS: The result of util.to_disk.
        """
        path = util.ensure_path(path)

        def save_pkuseg_model(path):
            if self.pkuseg_seg:
                if not path.exists():
                    path.mkdir(parents=True)
                self.pkuseg_seg.model.save(path)
                self.pkuseg_seg.feature_extractor.save(path)

        def save_pkuseg_processors(path):
            if self.pkuseg_seg:
                # word sets are sorted so the file contents are deterministic
                data = (
                    _get_pkuseg_trie_data(self.pkuseg_seg.preprocesser.trie),
                    self.pkuseg_seg.postprocesser.do_process,
                    sorted(self.pkuseg_seg.postprocesser.common_words),
                    sorted(self.pkuseg_seg.postprocesser.other_words),
                )
                srsly.write_msgpack(path, data)

        serializers = {
            "cfg": lambda p: srsly.write_json(p, self._get_config()),
            "pkuseg_model": save_pkuseg_model,
            "pkuseg_processors": save_pkuseg_processors,
        }
        return util.to_disk(path, serializers, [])

    def from_disk(self, path, **kwargs):
        """Restore the tokenizer from a directory written by ``to_disk``.

        path (str / Path): The directory to read from.
        RETURNS (ChineseTokenizer): The tokenizer itself.
        RAISES (ImportError): If pkuseg is the configured segmenter but
            spacy_pkuseg cannot be imported.
        """
        path = util.ensure_path(path)

        def load_pkuseg_model(path):
            try:
                import spacy_pkuseg
            except ImportError:
                if self.segmenter == Segmenter.pkuseg:
                    raise ImportError(
                        "spacy-pkuseg not installed. To use this model, "
                        + _PKUSEG_INSTALL_MSG
                    ) from None
                # pkuseg is not the configured segmenter: skip the saved model
                # instead of touching the module that failed to import
                return
            if path.exists():
                self.pkuseg_seg = spacy_pkuseg.pkuseg(path)

        def load_pkuseg_processors(path):
            try:
                import spacy_pkuseg
            except ImportError:
                if self.segmenter == Segmenter.pkuseg:
                    raise ImportError(
                        "spacy-pkuseg not installed. To use this model, "
                        + _PKUSEG_INSTALL_MSG
                    ) from None
                return
            if self.segmenter == Segmenter.pkuseg:
                data = srsly.read_msgpack(path)
                (user_dict, do_process, common_words, other_words) = data
                self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(user_dict)
                self.pkuseg_seg.postprocesser.do_process = do_process
                self.pkuseg_seg.postprocesser.common_words = set(common_words)
                self.pkuseg_seg.postprocesser.other_words = set(other_words)

        # "cfg" comes first so the loaders below see the restored segmenter
        serializers = {
            "cfg": lambda p: self._set_config(srsly.read_json(p)),
            "pkuseg_model": load_pkuseg_model,
            "pkuseg_processors": load_pkuseg_processors,
        }
        util.from_disk(path, serializers, [])
        return self
class ChineseDefaults(BaseDefaults):
    """Language defaults for Chinese.

    Bundles the parsed default config, the lexical attribute getters, the
    stop words and the writing system description.
    """

    config = load_config_from_str(DEFAULT_CONFIG)
    lex_attr_getters = LEX_ATTRS
    stop_words = STOP_WORDS
    writing_system = {
        "direction": "ltr",
        "has_case": False,
        "has_letters": False,
    }
class Chinese(Language):
    """Language class for Chinese, registered under the code "zh"."""

    lang = "zh"
    Defaults = ChineseDefaults
def try_jieba_import():
    """Import jieba, warm up its cache and return the module.

    RETURNS: The imported jieba module.
    RAISES (ImportError): If jieba cannot be imported; the message contains
        installation instructions.
    """
    # Only the import itself is guarded, so an ImportError raised while
    # warming up the cache is not misreported as "Jieba not installed".
    try:
        import jieba
    except ImportError:
        msg = (
            "Jieba not installed. To use jieba, install it with "
            "`pip install jieba` or from https://github.com/fxsjy/jieba"
        )
        raise ImportError(msg) from None
    # segment a short text to have jieba initialize its cache in advance
    list(jieba.cut("作为", cut_all=False))
    return jieba
def try_pkuseg_import(pkuseg_model: Optional[str], pkuseg_user_dict: Optional[str]):
    """Import spacy_pkuseg and create a pkuseg segmenter.

    pkuseg_model (Optional[str]): Model passed on to spacy_pkuseg.pkuseg.
    pkuseg_user_dict (Optional[str]): Passed on as the user_dict argument.
    RETURNS: The segmenter created by spacy_pkuseg.pkuseg.
    RAISES (ImportError): If spacy_pkuseg cannot be imported.
    RAISES (FileNotFoundError): If creating the segmenter fails because a
        file is missing; the message names the requested model.
    """
    try:
        import spacy_pkuseg
    except ImportError:
        raise ImportError(
            "spacy-pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
        ) from None

    try:
        segmenter = spacy_pkuseg.pkuseg(pkuseg_model, user_dict=pkuseg_user_dict)
    except FileNotFoundError:
        # A missing model name is reported as an empty location.
        model_location = str(pkuseg_model or "")
        raise FileNotFoundError(
            "Unable to load pkuseg model from: " + model_location
        ) from None
    return segmenter
def _get_pkuseg_trie_data(node, path=""):
data = []
for c, child_node in sorted(node.children.items()):
data.extend(_get_pkuseg_trie_data(child_node, path + c))
if node.isword:
data.append((path, node.usertag))
return data
# Public API of this module: only the Chinese language class is exported by a
# star-import.
__all__ = ["Chinese"]