Update ChineseTokenizer
* Allow `pkuseg_model` to be set to `None` on initialization
* Don't save config within tokenizer
* Force convert pkuseg_model to use pickle protocol 4 by reencoding with `pickle5` on serialization
* Update pkuseg serialization test
parent 3838b14148
commit 11e195d3ed
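The serialization fix centers on one technique: pkuseg writes `features.pkl` with `pickle.HIGHEST_PROTOCOL`, which is protocol 5 on Python 3.8 and unreadable on Python 3.6-3.7, so the commit reloads the pickle with the `pickle5` backport and rewrites it at protocol 4. A minimal standalone sketch of that conversion (the path `features.pkl` follows pkuseg's on-disk layout; `downgrade_pickle` is a hypothetical helper name and the error handling is illustrative, not spaCy's exact code):

    # Sketch of the protocol downgrade this commit applies during
    # serialization. pickle5 is the backport package installed with
    # `pip install pickle5`; it can read protocol-5 pickles on py3.6-3.7.
    import warnings

    import pickle5

    def downgrade_pickle(path="features.pkl"):  # hypothetical helper name
        try:
            with open(path, "rb") as fileh:
                data = pickle5.load(fileh)      # read, even if protocol 5
            with open(path, "wb") as fileh:
                pickle5.dump(data, fileh, protocol=4)  # rewrite at protocol 4
        except Exception:
            warnings.warn("Failed to force pickle protocol 4.")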
spacy/errors.py
@@ -670,10 +670,15 @@ class Errors:
              "'{token_attrs}'.")
     E999 = ("Unable to merge the `Doc` objects because they do not all share "
             "the same `Vocab`.")
-    E1000 = ("No pkuseg model available. Provide a pkuseg model when "
-             "initializing the pipeline:\n"
-             'cfg = {"tokenizer": {"segmenter": "pkuseg", "pkuseg_model": name_or_path}}\n'
-             'nlp = Chinese(config=cfg)')
+    E1000 = ("The Chinese word segmenter is pkuseg but no pkuseg model was "
+             "specified. Provide the name of a pretrained model or the path to "
+             "a model when initializing the pipeline:\n"
+             'config = {\n'
+             '   "@tokenizers": "spacy.zh.ChineseTokenizer",\n'
+             '   "segmenter": "pkuseg",\n'
+             '   "pkuseg_model": "default", # or "/path/to/pkuseg_model" \n'
+             '}\n'
+             'nlp = Chinese.from_config({"nlp": {"tokenizer": config}})')
     E1001 = ("Target token outside of matched span for match with tokens "
              "'{span}' and offset '{index}' matched by patterns '{patterns}'.")
     E1002 = ("Span index out of range.")
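For context, the configuration the new E1000 message points users to looks like this in practice, a sketch assuming the v3-style config API this commit targets, pkuseg installed, and a named or local pkuseg model being available:

    # Sketch of the initialization recommended by the new E1000 message;
    # assumes the registered spacy.zh.ChineseTokenizer factory and
    # `pip install pkuseg==0.0.25 pickle5`.
    from spacy.lang.zh import Chinese

    config = {
        "@tokenizers": "spacy.zh.ChineseTokenizer",
        "segmenter": "pkuseg",
        "pkuseg_model": "default",  # or "/path/to/pkuseg_model"
    }
    nlp = Chinese.from_config({"nlp": {"tokenizer": config}})
    words = [t.text for t in nlp("这是一个句子")]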
spacy/lang/zh/__init__.py
@@ -15,7 +15,8 @@ from .stop_words import STOP_WORDS
 from ... import util
 
 
-_PKUSEG_INSTALL_MSG = "install it with `pip install pkuseg==0.0.25` or from https://github.com/lancopku/pkuseg-python"
+_PKUSEG_INSTALL_MSG = "install pkuseg and pickle5 with `pip install pkuseg==0.0.25 pickle5`"
+_PKUSEG_PICKLE_WARNING = "Failed to force pkuseg model to use pickle protocol 4. If you're saving this model with python 3.8, it may not work with python 3.6-3.7."
 
 DEFAULT_CONFIG = """
 [nlp]
@@ -64,7 +65,7 @@ class ChineseTokenizer(DummyTokenizer):
         pkuseg_user_dict: Optional[str] = None,
     ):
         self.vocab = nlp.vocab
-        if isinstance(segmenter, Segmenter):  # we might have the Enum here
+        if isinstance(segmenter, Segmenter):
             segmenter = segmenter.value
         self.segmenter = segmenter
         self.pkuseg_model = pkuseg_model
@@ -136,18 +137,6 @@ class ChineseTokenizer(DummyTokenizer):
             warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
             warnings.warn(warn_msg)
 
-    def _get_config(self) -> Dict[str, Any]:
-        return {
-            "segmenter": self.segmenter,
-            "pkuseg_model": self.pkuseg_model,
-            "pkuseg_user_dict": self.pkuseg_user_dict,
-        }
-
-    def _set_config(self, config: Dict[str, Any] = {}) -> None:
-        self.segmenter = config.get("segmenter", Segmenter.char)
-        self.pkuseg_model = config.get("pkuseg_model", None)
-        self.pkuseg_user_dict = config.get("pkuseg_user_dict", "default")
-
     def to_bytes(self, **kwargs):
         pkuseg_features_b = b""
         pkuseg_weights_b = b""
@@ -157,6 +146,20 @@ class ChineseTokenizer(DummyTokenizer):
                 self.pkuseg_seg.feature_extractor.save(tempdir)
                 self.pkuseg_seg.model.save(tempdir)
                 tempdir = Path(tempdir)
+                # pkuseg saves features.pkl with pickle.HIGHEST_PROTOCOL, which
+                # means that it will be saved with pickle protocol 5 with
+                # python 3.8, which can't be reloaded with python 3.6-3.7.
+                # To try to make the model compatible with python 3.6+, reload
+                # the data with pickle5 and convert it back to protocol 4.
+                try:
+                    import pickle5
+
+                    with open(tempdir / "features.pkl", "rb") as fileh:
+                        features = pickle5.load(fileh)
+                    with open(tempdir / "features.pkl", "wb") as fileh:
+                        pickle5.dump(features, fileh, protocol=4)
+                except:
+                    warnings.warn(_PKUSEG_PICKLE_WARNING)
                 with open(tempdir / "features.pkl", "rb") as fileh:
                     pkuseg_features_b = fileh.read()
                 with open(tempdir / "weights.npz", "rb") as fileh:
@@ -168,7 +171,6 @@ class ChineseTokenizer(DummyTokenizer):
                 sorted(list(self.pkuseg_seg.postprocesser.other_words)),
             )
         serializers = {
-            "cfg": lambda: srsly.json_dumps(self._get_config()),
             "pkuseg_features": lambda: pkuseg_features_b,
             "pkuseg_weights": lambda: pkuseg_weights_b,
             "pkuseg_processors": lambda: srsly.msgpack_dumps(pkuseg_processors_data),
@@ -188,7 +190,6 @@ class ChineseTokenizer(DummyTokenizer):
             pkuseg_data["processors_data"] = srsly.msgpack_loads(b)
 
         deserializers = {
-            "cfg": lambda b: self._set_config(srsly.json_loads(b)),
             "pkuseg_features": deserialize_pkuseg_features,
             "pkuseg_weights": deserialize_pkuseg_weights,
             "pkuseg_processors": deserialize_pkuseg_processors,
@@ -229,6 +230,16 @@ class ChineseTokenizer(DummyTokenizer):
                 path.mkdir(parents=True)
             self.pkuseg_seg.model.save(path)
             self.pkuseg_seg.feature_extractor.save(path)
+            # try to convert features.pkl to pickle protocol 4
+            try:
+                import pickle5
+
+                with open(path / "features.pkl", "rb") as fileh:
+                    features = pickle5.load(fileh)
+                with open(path / "features.pkl", "wb") as fileh:
+                    pickle5.dump(features, fileh, protocol=4)
+            except:
+                warnings.warn(_PKUSEG_PICKLE_WARNING)
 
         def save_pkuseg_processors(path):
             if self.pkuseg_seg:
@@ -241,7 +252,6 @@ class ChineseTokenizer(DummyTokenizer):
             srsly.write_msgpack(path, data)
 
         serializers = {
-            "cfg": lambda p: srsly.write_json(p, self._get_config()),
             "pkuseg_model": lambda p: save_pkuseg_model(p),
             "pkuseg_processors": lambda p: save_pkuseg_processors(p),
         }
@@ -277,7 +287,6 @@ class ChineseTokenizer(DummyTokenizer):
             self.pkuseg_seg.postprocesser.other_words = set(other_words)
 
         serializers = {
-            "cfg": lambda p: self._set_config(srsly.read_json(p)),
             "pkuseg_model": lambda p: load_pkuseg_model(p),
             "pkuseg_processors": lambda p: load_pkuseg_processors(p),
         }
@@ -314,21 +323,14 @@ def try_jieba_import(segmenter: str) -> None:
         raise ImportError(msg) from None
 
 
-def try_pkuseg_import(segmenter: str, pkuseg_model: str, pkuseg_user_dict: str) -> None:
+def try_pkuseg_import(segmenter: str, pkuseg_model: Optional[str], pkuseg_user_dict: str) -> None:
     try:
         import pkuseg
 
-        if pkuseg_model:
+        if pkuseg_model is None:
+            return None
+        else:
             return pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
-        elif segmenter == Segmenter.pkuseg:
-            msg = (
-                "The Chinese word segmenter is 'pkuseg' but no pkuseg model "
-                "was specified. Please provide the name of a pretrained model "
-                "or the path to a model with:\n"
-                'cfg = {"nlp": {"tokenizer": {"segmenter": "pkuseg", "pkuseg_model": name_or_path }}\n'
-                "nlp = Chinese.from_config(cfg)"
-            )
-            raise ValueError(msg)
     except ImportError:
         if segmenter == Segmenter.pkuseg:
             msg = "pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
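With this change, `try_pkuseg_import` tolerates a missing model: passing `pkuseg_model=None` now returns `None` instead of raising at construction time, deferring the error (presumably the reworded E1000 above) until a pkuseg model is actually needed. A small sketch of the new behavior, assuming the patched helper and an installed pkuseg:

    # Sketch against the patched helper above; "pkuseg" is the segmenter
    # string (Segmenter.pkuseg.value) and pkuseg must be installed.
    seg = try_pkuseg_import("pkuseg", pkuseg_model=None, pkuseg_user_dict="default")
    assert seg is None  # tokenizer can be built now and given a model later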
spacy/tests/lang/zh/test_serialize.py
@@ -27,9 +27,10 @@ def test_zh_tokenizer_serialize_jieba(zh_tokenizer_jieba):
 
 @pytest.mark.slow
 def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
-    nlp = Chinese(
-        meta={
-            "tokenizer": {"config": {"segmenter": "pkuseg", "pkuseg_model": "medicine"}}
-        }
-    )
+    config = {
+        "@tokenizers": "spacy.zh.ChineseTokenizer",
+        "segmenter": "pkuseg",
+        "pkuseg_model": "medicine",
+    }
+    nlp = Chinese.from_config({"nlp": {"tokenizer": config}})
     zh_tokenizer_serialize(nlp.tokenizer)
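The `zh_tokenizer_serialize` helper used by this test presumably round-trips the tokenizer through its byte representation; a sketch of such a check under that assumption (the real helper lives elsewhere in test_serialize.py and may also exercise disk serialization):

    # Hypothetical reconstruction of the round-trip helper, for illustration only.
    from spacy.lang.zh import Chinese

    def zh_tokenizer_serialize(zh_tokenizer):
        tokenizer_bytes = zh_tokenizer.to_bytes()       # serialize, incl. pkuseg data
        nlp = Chinese()
        nlp.tokenizer.from_bytes(tokenizer_bytes)       # restore into a fresh pipeline
        assert tokenizer_bytes == nlp.tokenizer.to_bytes()  # byte-identical round trip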