mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 20:28:20 +03:00
320 lines
12 KiB
Python
320 lines
12 KiB
Python
# coding: utf8
|
|
from __future__ import unicode_literals
|
|
|
|
import tempfile
|
|
import srsly
|
|
from pathlib import Path
|
|
from collections import OrderedDict
|
|
from ...attrs import LANG
|
|
from ...language import Language
|
|
from ...tokens import Doc
|
|
from ...util import DummyTokenizer
|
|
from ..tokenizer_exceptions import BASE_EXCEPTIONS
|
|
from .lex_attrs import LEX_ATTRS
|
|
from .stop_words import STOP_WORDS
|
|
from .tag_map import TAG_MAP
|
|
from ... import util
|
|
|
|
|
|
# Installation hint that is appended to the pkuseg-related ImportError
# messages raised in this module.
_PKUSEG_INSTALL_MSG = "install it with `pip install pkuseg==0.0.22` or from https://github.com/lancopku/pkuseg-python"
|
|
|
|
|
|
def try_jieba_import(use_jieba):
    """Import jieba and hand back the module.

    use_jieba (bool): Whether jieba is actually required by the caller.
    RETURNS: The imported `jieba` module, or None if the import failed and
        `use_jieba` is falsy.
    RAISES: ImportError with installation instructions if the import failed
        and `use_jieba` is truthy.
    """
    try:
        import jieba

        # segment a short text to have jieba initialize its cache in advance
        warmup_tokens = jieba.cut("作为", cut_all=False)
        list(warmup_tokens)
        return jieba
    except ImportError:
        if not use_jieba:
            # jieba is optional for this caller: signal "unavailable" quietly
            return None
        raise ImportError(
            "Jieba not installed. Either set the default to False with "
            "`from spacy.lang.zh import ChineseDefaults; ChineseDefaults.use_jieba = False`, "
            "or install it with `pip install jieba` or from "
            "https://github.com/fxsjy/jieba"
        )
|
|
|
|
|
|
def try_pkuseg_import(use_pkuseg, pkuseg_model, pkuseg_user_dict):
    """Import pkuseg and, if a model is given, build a pkuseg segmenter.

    use_pkuseg (bool): Whether pkuseg is actually required by the caller.
        Controls whether failures are raised or silently ignored.
    pkuseg_model: Name of or path to the pkuseg model; passed as the first
        argument to `pkuseg.pkuseg`. May be None/empty.
    pkuseg_user_dict: Passed as the second argument to `pkuseg.pkuseg`.
    RETURNS: The object returned by `pkuseg.pkuseg(...)` if `pkuseg_model` is
        truthy; otherwise None (also None when a failure is ignored because
        `use_pkuseg` is falsy).
    RAISES:
        ValueError: `use_pkuseg` is truthy but no model was specified.
        ImportError: pkuseg cannot be imported and `use_pkuseg` is truthy.
        FileNotFoundError: the model could not be loaded and `use_pkuseg` is
            truthy.
    """
    try:
        import pkuseg

        if pkuseg_model:
            return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
        elif use_pkuseg:
            msg = (
                "Chinese.use_pkuseg is True but no pkuseg model was specified. "
                "Please provide the name of a pretrained model "
                "or the path to a model with "
                '`Chinese(meta={"tokenizer": {"config": {"pkuseg_model": name_or_path}}}).'
            )
            raise ValueError(msg)
    except ImportError:
        if use_pkuseg:
            msg = (
                "pkuseg not installed. Either set Chinese.use_pkuseg = False, "
                "or " + _PKUSEG_INSTALL_MSG
            )
            raise ImportError(msg)
    except FileNotFoundError:
        if use_pkuseg:
            # The model may be given as a pathlib.Path; format it instead of
            # concatenating so a TypeError cannot mask the real error.
            msg = "Unable to load pkuseg model from: {}".format(pkuseg_model)
            raise FileNotFoundError(msg)
    # Either nothing was requested or the failure was deliberately ignored.
    return None
|
|
|
|
|
|
class ChineseTokenizer(DummyTokenizer):
    """Tokenizer for Chinese text using jieba, pkuseg or character splitting.

    The segmenter used at call time is chosen from the `use_jieba`,
    `use_pkuseg` and `require_pkuseg` flags: `require_pkuseg` forces pkuseg,
    otherwise jieba takes precedence over pkuseg, and if neither is enabled
    the text is split into individual characters.
    """

    def __init__(self, cls, nlp=None, config=None):
        """Set up the tokenizer and try to load the optional segmenters.

        cls: Defaults class; supplies the fallback values `cls.use_jieba` and
            `cls.use_pkuseg` and, when `nlp` is None, `cls.create_vocab`.
        nlp: Optional pipeline object whose `vocab` is used if given.
        config (dict): Tokenizer settings. Keys read here: "use_jieba",
            "use_pkuseg", "require_pkuseg", "pkuseg_model" and
            "pkuseg_user_dict". NOTE: the first four keys are removed from
            the passed-in dict in place (see below).
        """
        if config is None:
            config = {}
        self.use_jieba = config.get("use_jieba", cls.use_jieba)
        self.use_pkuseg = config.get("use_pkuseg", cls.use_pkuseg)
        self.require_pkuseg = config.get("require_pkuseg", False)
        self.vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
        self.jieba_seg = try_jieba_import(self.use_jieba)
        self.pkuseg_seg = try_pkuseg_import(
            self.use_pkuseg,
            pkuseg_model=config.get("pkuseg_model", None),
            pkuseg_user_dict=config.get("pkuseg_user_dict", "default"),
        )
        # remove relevant settings from config so they're not also saved in
        # Language.meta (deliberate in-place mutation of the caller's dict)
        for key in ("use_jieba", "use_pkuseg", "require_pkuseg", "pkuseg_model"):
            config.pop(key, None)
        self.tokenizer = Language.Defaults().create_tokenizer(nlp)

    def __call__(self, text):
        """Segment `text` and build a Doc from the resulting words.

        text: The text to tokenize.
        RETURNS (Doc): Doc created from the words/spaces computed by
            `util.get_words_and_spaces`.
        """
        # require_pkuseg overrides both flags: jieba off, pkuseg on
        use_jieba = self.use_jieba and not self.require_pkuseg
        use_pkuseg = self.use_pkuseg or self.require_pkuseg
        if use_jieba:
            # drop empty segments returned by jieba
            words = [x for x in self.jieba_seg.cut(text, cut_all=False) if x]
        elif use_pkuseg:
            words = self.pkuseg_seg.cut(text)
        else:
            # split into individual characters
            words = list(text)
        (words, spaces) = util.get_words_and_spaces(words, text)
        return Doc(self.vocab, words=words, spaces=spaces)

    def pkuseg_update_user_dict(self, words, reset=False):
        """Insert words into the pkuseg preprocesser's user dictionary.

        Does nothing if no pkuseg segmenter is loaded.

        words: Iterable of strings; each is stripped and inserted with an
            empty tag.
        reset (bool): If True, replace the preprocesser with a fresh
            `pkuseg.Preprocesser(None)` before inserting.
        RAISES: ImportError if `reset` is set, pkuseg cannot be imported and
            `self.use_pkuseg` is truthy.
        """
        if not self.pkuseg_seg:
            return
        if reset:
            try:
                import pkuseg

                self.pkuseg_seg.preprocesser = pkuseg.Preprocesser(None)
            except ImportError:
                if self.use_pkuseg:
                    msg = (
                        "pkuseg not installed: unable to reset pkuseg "
                        "user dict. Please " + _PKUSEG_INSTALL_MSG
                    )
                    raise ImportError(msg)
        for word in words:
            self.pkuseg_seg.preprocesser.insert(word.strip(), "")

    def _get_config(self):
        """Return the segmenter flags as an OrderedDict for serialization."""
        config = OrderedDict(
            (
                ("use_jieba", self.use_jieba),
                ("use_pkuseg", self.use_pkuseg),
                ("require_pkuseg", self.require_pkuseg),
            )
        )
        return config

    def _set_config(self, config=None):
        """Set the segmenter flags from `config`; missing keys become False."""
        if config is None:
            config = {}
        self.use_jieba = config.get("use_jieba", False)
        self.use_pkuseg = config.get("use_pkuseg", False)
        self.require_pkuseg = config.get("require_pkuseg", False)

    def _get_pkuseg_processors_data(self):
        """Collect the pkuseg pre-/postprocesser state as a 4-tuple.

        Callers must make sure `self.pkuseg_seg` is set. The word
        collections are sorted so the output is deterministic.
        """
        postprocesser = self.pkuseg_seg.postprocesser
        return (
            _get_pkuseg_trie_data(self.pkuseg_seg.preprocesser.trie),
            postprocesser.do_process,
            sorted(postprocesser.common_words),
            sorted(postprocesser.other_words),
        )

    def _set_pkuseg_processors_data(self, pkuseg, data):
        """Restore state produced by `_get_pkuseg_processors_data`.

        pkuseg: The imported pkuseg module (used for `Preprocesser`).
        data: 4-item sequence (user_dict, do_process, common_words,
            other_words).
        """
        (user_dict, do_process, common_words, other_words) = data
        self.pkuseg_seg.preprocesser = pkuseg.Preprocesser(user_dict)
        self.pkuseg_seg.postprocesser.do_process = do_process
        self.pkuseg_seg.postprocesser.common_words = set(common_words)
        self.pkuseg_seg.postprocesser.other_words = set(other_words)

    def to_bytes(self, **kwargs):
        """Serialize the config and, if loaded, the pkuseg model to bytes.

        The pkuseg feature extractor and model are saved into a temporary
        directory and the resulting "features.pkl" and "weights.npz" files
        are read back as bytes.

        RETURNS: The result of `util.to_bytes` for the entries "cfg",
            "pkuseg_features", "pkuseg_weights" and "pkuseg_processors".
        """
        pkuseg_features_b = b""
        pkuseg_weights_b = b""
        pkuseg_processors_data = None
        if self.pkuseg_seg:
            with tempfile.TemporaryDirectory() as tempdir:
                self.pkuseg_seg.feature_extractor.save(tempdir)
                self.pkuseg_seg.model.save(tempdir)
                tempdir_path = Path(tempdir)
                with (tempdir_path / "features.pkl").open("rb") as fileh:
                    pkuseg_features_b = fileh.read()
                with (tempdir_path / "weights.npz").open("rb") as fileh:
                    pkuseg_weights_b = fileh.read()
            pkuseg_processors_data = self._get_pkuseg_processors_data()
        serializers = OrderedDict(
            (
                ("cfg", lambda: srsly.json_dumps(self._get_config())),
                ("pkuseg_features", lambda: pkuseg_features_b),
                ("pkuseg_weights", lambda: pkuseg_weights_b),
                (
                    "pkuseg_processors",
                    lambda: srsly.msgpack_dumps(pkuseg_processors_data),
                ),
            )
        )
        return util.to_bytes(serializers, [])

    def from_bytes(self, data, **kwargs):
        """Load the config and, if present, the pkuseg model from bytes.

        data: Bytes as produced by `to_bytes`.
        RETURNS (ChineseTokenizer): `self`.
        RAISES: ImportError if pkuseg model data is present but pkuseg
            cannot be imported.
        """
        pkuseg_data = {"features_b": b"", "weights_b": b"", "processors_data": None}

        def deserialize_pkuseg_features(b):
            pkuseg_data["features_b"] = b

        def deserialize_pkuseg_weights(b):
            pkuseg_data["weights_b"] = b

        def deserialize_pkuseg_processors(b):
            pkuseg_data["processors_data"] = srsly.msgpack_loads(b)

        deserializers = OrderedDict(
            (
                ("cfg", lambda b: self._set_config(srsly.json_loads(b))),
                ("pkuseg_features", deserialize_pkuseg_features),
                ("pkuseg_weights", deserialize_pkuseg_weights),
                ("pkuseg_processors", deserialize_pkuseg_processors),
            )
        )
        util.from_bytes(data, deserializers, [])

        if pkuseg_data["features_b"] and pkuseg_data["weights_b"]:
            # pkuseg loads models from a directory, so write the two files
            # to a temporary directory first
            with tempfile.TemporaryDirectory() as tempdir:
                tempdir_path = Path(tempdir)
                with (tempdir_path / "features.pkl").open("wb") as fileh:
                    fileh.write(pkuseg_data["features_b"])
                with (tempdir_path / "weights.npz").open("wb") as fileh:
                    fileh.write(pkuseg_data["weights_b"])
                try:
                    import pkuseg
                except ImportError:
                    raise ImportError(
                        "pkuseg not installed. To use this model, "
                        + _PKUSEG_INSTALL_MSG
                    )
                self.pkuseg_seg = pkuseg.pkuseg(str(tempdir_path))
            if pkuseg_data["processors_data"]:
                self._set_pkuseg_processors_data(
                    pkuseg, pkuseg_data["processors_data"]
                )

        return self

    def to_disk(self, path, **kwargs):
        """Save the config and, if loaded, the pkuseg model under `path`.

        path: Target directory (passed through `util.ensure_path`).
        RETURNS: The result of `util.to_disk` for the entries "cfg",
            "pkuseg_model" and "pkuseg_processors".
        """
        path = util.ensure_path(path)

        def save_pkuseg_model(path):
            if self.pkuseg_seg:
                if not path.exists():
                    path.mkdir(parents=True)
                self.pkuseg_seg.model.save(path)
                self.pkuseg_seg.feature_extractor.save(path)

        def save_pkuseg_processors(path):
            if self.pkuseg_seg:
                srsly.write_msgpack(path, self._get_pkuseg_processors_data())

        serializers = OrderedDict(
            (
                ("cfg", lambda p: srsly.write_json(p, self._get_config())),
                ("pkuseg_model", save_pkuseg_model),
                ("pkuseg_processors", save_pkuseg_processors),
            )
        )
        return util.to_disk(path, serializers, [])

    def from_disk(self, path, **kwargs):
        """Load the config and, if present, the pkuseg model from `path`.

        path: Source directory (passed through `util.ensure_path`).
        RETURNS (ChineseTokenizer): `self`.
        RAISES: ImportError if pkuseg cannot be imported while the loaded
            config has `use_pkuseg` enabled.
        """
        path = util.ensure_path(path)

        def import_pkuseg():
            # Returns the pkuseg module, or None when it is not installed and
            # not requested; raises if it is requested but missing.
            try:
                import pkuseg
            except ImportError:
                if self.use_pkuseg:
                    raise ImportError(
                        "pkuseg not installed. To use this model, "
                        + _PKUSEG_INSTALL_MSG
                    )
                return None
            return pkuseg

        def load_pkuseg_model(path):
            pkuseg = import_pkuseg()
            if pkuseg is None:
                # avoid touching the unbound module name further down
                return
            if path.exists():
                self.pkuseg_seg = pkuseg.pkuseg(path)

        def load_pkuseg_processors(path):
            pkuseg = import_pkuseg()
            if pkuseg is None:
                return
            if self.pkuseg_seg:
                self._set_pkuseg_processors_data(pkuseg, srsly.read_msgpack(path))

        # "cfg" comes first so `self.use_pkuseg` is set before the loaders run
        deserializers = OrderedDict(
            (
                ("cfg", lambda p: self._set_config(srsly.read_json(p))),
                ("pkuseg_model", load_pkuseg_model),
                ("pkuseg_processors", load_pkuseg_processors),
            )
        )
        util.from_disk(path, deserializers, [])
        return self
|
|
|
|
|
|
class ChineseDefaults(Language.Defaults):
    """Language defaults for Chinese: lexical attributes, stop words, tag
    map and the default segmenter switches used by ChineseTokenizer."""

    # Default segmenter switches; ChineseTokenizer falls back to these when
    # its config does not contain the corresponding keys.
    use_jieba = True
    use_pkuseg = False

    writing_system = {"direction": "ltr", "has_case": False, "has_letters": False}
    tag_map = TAG_MAP
    stop_words = STOP_WORDS
    tokenizer_exceptions = BASE_EXCEPTIONS

    # Start from a copy of the base getters so the parent mapping is left
    # untouched, then layer the Chinese-specific ones on top.
    lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
    lex_attr_getters.update(LEX_ATTRS)
    lex_attr_getters[LANG] = lambda text: "zh"

    @classmethod
    def create_tokenizer(cls, nlp=None, config={}):
        """Build a ChineseTokenizer from this defaults class, `nlp` and
        `config`."""
        return ChineseTokenizer(cls, nlp=nlp, config=config)
|
|
|
|
|
|
class Chinese(Language):
    """Language class for Chinese ("zh") that uses ChineseDefaults."""

    Defaults = ChineseDefaults  # override defaults
    lang = "zh"

    def make_doc(self, text):
        """Run this pipeline's tokenizer on `text` and return its result."""
        doc = self.tokenizer(text)
        return doc
|
|
|
|
|
|
def _get_pkuseg_trie_data(node, path=""):
|
|
data = []
|
|
for c, child_node in sorted(node.children.items()):
|
|
data.extend(_get_pkuseg_trie_data(child_node, path + c))
|
|
if node.isword:
|
|
data.append((path, node.usertag))
|
|
return data
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["Chinese"]
|