Mirror of https://github.com/explosion/spaCy.git (synced 2024-12-24 17:06:29 +03:00)
Update Japanese tokenizer config and add serialization (#5562)
* Use `config` dict for tokenizer settings
* Add serialization of split mode setting
* Add tests for tokenizer split modes and serialization of split mode setting

Based on #5561
parent 456bf47f51
commit 3bf111585d
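With this change the Sudachi split mode can be selected through the tokenizer config in the model meta. A minimal usage sketch, mirroring the tests added below (assumes sudachipy and a Sudachi dictionary such as SudachiDict-core are installed):

    from spacy.lang.ja import Japanese

    # split_mode is "A" (short units, the default), "B", or "C" (longest units)
    nlp = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
    doc = nlp("選挙管理委員会")
    print(len(doc))  # 3 tokens in mode B, per the new split-mode tests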
spacy/lang/ja/__init__.py:

@@ -1,7 +1,8 @@
 # encoding: utf8
 from __future__ import unicode_literals, print_function
 
-from collections import namedtuple
+import srsly
+from collections import namedtuple, OrderedDict
 
 from .stop_words import STOP_WORDS
 from .syntax_iterators import SYNTAX_ITERATORS
@@ -10,12 +11,13 @@ from .tag_orth_map import TAG_ORTH_MAP
 from .tag_bigram_map import TAG_BIGRAM_MAP
 from ...attrs import LANG
 from ...compat import copy_reg
+from ...errors import Errors
 from ...language import Language
 from ...symbols import POS
 from ...tokens import Doc
 from ...util import DummyTokenizer
+from ... import util
 
-from ...errors import Errors
 
 # Hold the attributes we need with convenient names
 DetailedToken = namedtuple("DetailedToken", ["surface", "pos", "lemma"])
@@ -26,14 +28,20 @@ DummyNode = namedtuple("DummyNode", ["surface", "pos", "lemma"])
 DummySpace = DummyNode(" ", " ", " ")
 
 
-def try_sudachi_import():
+def try_sudachi_import(split_mode="A"):
     """SudachiPy is required for Japanese support, so check for it.
-    It it's not available blow up and explain how to fix it."""
+    It it's not available blow up and explain how to fix it.
+    split_mode should be one of these values: "A", "B", "C", None->"A"."""
     try:
         from sudachipy import dictionary, tokenizer
-
+        split_mode = {
+            None: tokenizer.Tokenizer.SplitMode.A,
+            "A": tokenizer.Tokenizer.SplitMode.A,
+            "B": tokenizer.Tokenizer.SplitMode.B,
+            "C": tokenizer.Tokenizer.SplitMode.C,
+        }[split_mode]
         tok = dictionary.Dictionary().create(
-            mode=tokenizer.Tokenizer.SplitMode.A
+            mode=split_mode
        )
         return tok
     except ImportError:
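The dictionary above maps the string setting onto SudachiPy's own split modes. For reference, used directly that API looks roughly like this (a sketch, assuming sudachipy and a Sudachi dictionary package are installed; the mode names are Sudachi's: A = shortest units, C = longest):

    from sudachipy import dictionary, tokenizer

    # with mode C, compounds such as 選挙管理委員会 stay a single morpheme
    tok = dictionary.Dictionary().create(mode=tokenizer.Tokenizer.SplitMode.C)
    print([m.surface() for m in tok.tokenize("選挙管理委員会")])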
@@ -164,9 +172,10 @@ def get_words_lemmas_tags_spaces(dtokens, text, gap_tag=("空白", "")
 
 
 class JapaneseTokenizer(DummyTokenizer):
-    def __init__(self, cls, nlp=None):
+    def __init__(self, cls, nlp=None, config={}):
         self.vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
-        self.tokenizer = try_sudachi_import()
+        self.split_mode = config.get("split_mode", None)
+        self.tokenizer = try_sudachi_import(self.split_mode)
 
     def __call__(self, text):
         dtokens = get_dtokens(self.tokenizer, text)
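When no config is passed, split_mode stays None and try_sudachi_import falls back to Sudachi mode A, so existing pipelines keep their previous behaviour. A small sketch of the default path (assuming SudachiPy is installed):

    from spacy.lang.ja import Japanese

    nlp = Japanese()                         # no tokenizer config supplied
    assert nlp.tokenizer.split_mode is None  # None is handed to try_sudachi_import, which uses mode A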
@@ -193,6 +202,54 @@ class JapaneseTokenizer(DummyTokenizer):
         separate_sentences(doc)
         return doc
+
+    def _get_config(self):
+        config = OrderedDict(
+            (
+                ("split_mode", self.split_mode),
+            )
+        )
+        return config
+
+    def _set_config(self, config={}):
+        self.split_mode = config.get("split_mode", None)
+
+    def to_bytes(self, **kwargs):
+        serializers = OrderedDict(
+            (
+                ("cfg", lambda: srsly.json_dumps(self._get_config())),
+            )
+        )
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, data, **kwargs):
+        deserializers = OrderedDict(
+            (
+                ("cfg", lambda b: self._set_config(srsly.json_loads(b))),
+            )
+        )
+        util.from_bytes(data, deserializers, [])
+        self.tokenizer = try_sudachi_import(self.split_mode)
+        return self
+
+    def to_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = OrderedDict(
+            (
+                ("cfg", lambda p: srsly.write_json(p, self._get_config())),
+            )
+        )
+        return util.to_disk(path, serializers, [])
+
+    def from_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = OrderedDict(
+            (
+                ("cfg", lambda p: self._set_config(srsly.read_json(p))),
+            )
+        )
+        util.from_disk(path, serializers, [])
+        self.tokenizer = try_sudachi_import(self.split_mode)
 
 
 class JapaneseDefaults(Language.Defaults):
     lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
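These methods serialize only the split mode setting (under the "cfg" key), and from_bytes/from_disk re-create the underlying Sudachi tokenizer with the restored mode. A roundtrip sketch at the tokenizer level, mirroring the new tests (assumes SudachiPy is installed):

    from spacy.lang.ja import Japanese

    nlp_b = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
    data = nlp_b.tokenizer.to_bytes()

    nlp = Japanese()                # starts with the default split mode
    nlp.tokenizer.from_bytes(data)  # restores "B" and re-imports Sudachi with it
    assert nlp.tokenizer.split_mode == "B"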
@@ -203,8 +260,8 @@ class JapaneseDefaults(Language.Defaults):
     writing_system = {"direction": "ltr", "has_case": False, "has_letters": False}
 
     @classmethod
-    def create_tokenizer(cls, nlp=None):
-        return JapaneseTokenizer(cls, nlp)
+    def create_tokenizer(cls, nlp=None, config={}):
+        return JapaneseTokenizer(cls, nlp, config)
 
 
 class Japanese(Language):
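The config dict is threaded through the classmethod as well, so a standalone tokenizer can be built without a full pipeline. A sketch based on the signature above (the direct call with no nlp object is an assumption, not something the commit itself exercises):

    from spacy.lang.ja import Japanese, JapaneseTokenizer

    # Japanese.Defaults is the JapaneseDefaults class shown above
    tok = Japanese.Defaults.create_tokenizer(config={"split_mode": "C"})
    assert isinstance(tok, JapaneseTokenizer)
    assert tok.split_mode == "C"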
spacy/tests/lang/ja/test_serialize.py (new file, 37 lines):
@@ -0,0 +1,37 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import pytest
+from spacy.lang.ja import Japanese
+from ...util import make_tempdir
+
+
+def test_ja_tokenizer_serialize(ja_tokenizer):
+    tokenizer_bytes = ja_tokenizer.to_bytes()
+    nlp = Japanese()
+    nlp.tokenizer.from_bytes(tokenizer_bytes)
+    assert tokenizer_bytes == nlp.tokenizer.to_bytes()
+    assert nlp.tokenizer.split_mode == None
+
+    with make_tempdir() as d:
+        file_path = d / "tokenizer"
+        ja_tokenizer.to_disk(file_path)
+        nlp = Japanese()
+        nlp.tokenizer.from_disk(file_path)
+        assert tokenizer_bytes == nlp.tokenizer.to_bytes()
+        assert nlp.tokenizer.split_mode == None
+
+    # split mode is (de)serialized correctly
+    nlp = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
+    nlp_r = Japanese()
+    nlp_bytes = nlp.to_bytes()
+    nlp_r.from_bytes(nlp_bytes)
+    assert nlp_bytes == nlp_r.to_bytes()
+    assert nlp_r.tokenizer.split_mode == "B"
+
+    with make_tempdir() as d:
+        nlp.to_disk(d)
+        nlp_r = Japanese()
+        nlp_r.from_disk(d)
+        assert nlp_bytes == nlp_r.to_bytes()
+        assert nlp_r.tokenizer.split_mode == "B"
spacy/tests/lang/ja/test_tokenizer.py:

@@ -3,6 +3,8 @@ from __future__ import unicode_literals
 
 import pytest
 
+from ...tokenizer.test_naughty_strings import NAUGHTY_STRINGS
+from spacy.lang.ja import Japanese
 
 # fmt: off
 TOKENIZER_TESTS = [
@@ -55,21 +57,39 @@ def test_ja_tokenizer_pos(ja_tokenizer, text, expected_pos):
     pos = [token.pos_ for token in ja_tokenizer(text)]
     assert pos == expected_pos
 
 
 @pytest.mark.parametrize("text,expected_sents", SENTENCE_TESTS)
 def test_ja_tokenizer_pos(ja_tokenizer, text, expected_sents):
     sents = [str(sent) for sent in ja_tokenizer(text).sents]
     assert sents == expected_sents
 
 
-def test_extra_spaces(ja_tokenizer):
+def test_ja_tokenizer_extra_spaces(ja_tokenizer):
     # note: three spaces after "I"
     tokens = ja_tokenizer("I   like cheese.")
     assert tokens[1].orth_ == "   "
 
-from ...tokenizer.test_naughty_strings import NAUGHTY_STRINGS
 
 @pytest.mark.parametrize("text", NAUGHTY_STRINGS)
-def test_tokenizer_naughty_strings(ja_tokenizer, text):
+def test_ja_tokenizer_naughty_strings(ja_tokenizer, text):
     tokens = ja_tokenizer(text)
     assert tokens.text_with_ws == text
+
+
+@pytest.mark.parametrize("text,len_a,len_b,len_c",
+    [
+        ("選挙管理委員会", 4, 3, 1),
+        ("客室乗務員", 3, 2, 1),
+        ("労働者協同組合", 4, 3, 1),
+        ("機能性食品", 3, 2, 1),
+    ]
+)
+def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c):
+    nlp_a = Japanese(meta={"tokenizer": {"config": {"split_mode": "A"}}})
+    nlp_b = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
+    nlp_c = Japanese(meta={"tokenizer": {"config": {"split_mode": "C"}}})
+
+    assert len(ja_tokenizer(text)) == len_a
+    assert len(nlp_a(text)) == len_a
+    assert len(nlp_b(text)) == len_b
+    assert len(nlp_c(text)) == len_c
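Both test files live under the Japanese test package, so the new and updated tests can be run locally with something like pytest spacy/tests/lang/ja/ (SudachiPy and a Sudachi dictionary need to be installed for the ja tokenizer fixture to load).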