diff --git a/spacy/lang/ja/__init__.py b/spacy/lang/ja/__init__.py
index a623c7bdd..294c6b38d 100644
--- a/spacy/lang/ja/__init__.py
+++ b/spacy/lang/ja/__init__.py
@@ -1,7 +1,8 @@
 # encoding: utf8
 from __future__ import unicode_literals, print_function
 
-from collections import namedtuple
+import srsly
+from collections import namedtuple, OrderedDict
 
 from .stop_words import STOP_WORDS
 from .syntax_iterators import SYNTAX_ITERATORS
@@ -10,12 +11,13 @@ from .tag_orth_map import TAG_ORTH_MAP
 from .tag_bigram_map import TAG_BIGRAM_MAP
 from ...attrs import LANG
 from ...compat import copy_reg
+from ...errors import Errors
 from ...language import Language
 from ...symbols import POS
 from ...tokens import Doc
 from ...util import DummyTokenizer
+from ... import util
 
-from ...errors import Errors
 
 # Hold the attributes we need with convenient names
 DetailedToken = namedtuple("DetailedToken", ["surface", "pos", "lemma"])
@@ -26,14 +28,20 @@ DummyNode = namedtuple("DummyNode", ["surface", "pos", "lemma"])
 DummySpace = DummyNode(" ", " ", " ")
 
 
-def try_sudachi_import():
+def try_sudachi_import(split_mode="A"):
     """SudachiPy is required for Japanese support, so check for it.
-    It it's not available blow up and explain how to fix it."""
+    It it's not available blow up and explain how to fix it.
+    split_mode should be one of these values: "A", "B", "C", None->"A"."""
     try:
         from sudachipy import dictionary, tokenizer
-
+        split_mode = {
+            None: tokenizer.Tokenizer.SplitMode.A,
+            "A": tokenizer.Tokenizer.SplitMode.A,
+            "B": tokenizer.Tokenizer.SplitMode.B,
+            "C": tokenizer.Tokenizer.SplitMode.C,
+        }[split_mode]
         tok = dictionary.Dictionary().create(
-            mode=tokenizer.Tokenizer.SplitMode.A
+            mode=split_mode
         )
         return tok
     except ImportError:
@@ -164,9 +172,10 @@ def get_words_lemmas_tags_spaces(dtokens, text, gap_tag=("空白", "")):
 
 
 class JapaneseTokenizer(DummyTokenizer):
-    def __init__(self, cls, nlp=None):
+    def __init__(self, cls, nlp=None, config={}):
         self.vocab = nlp.vocab if nlp is not None else cls.create_vocab(nlp)
-        self.tokenizer = try_sudachi_import()
+        self.split_mode = config.get("split_mode", None)
+        self.tokenizer = try_sudachi_import(self.split_mode)
 
     def __call__(self, text):
         dtokens = get_dtokens(self.tokenizer, text)
@@ -193,6 +202,54 @@ class JapaneseTokenizer(DummyTokenizer):
         separate_sentences(doc)
         return doc
 
+    def _get_config(self):
+        config = OrderedDict(
+            (
+                ("split_mode", self.split_mode),
+            )
+        )
+        return config
+
+    def _set_config(self, config={}):
+        self.split_mode = config.get("split_mode", None)
+
+    def to_bytes(self, **kwargs):
+        serializers = OrderedDict(
+            (
+                ("cfg", lambda: srsly.json_dumps(self._get_config())),
+            )
+        )
+        return util.to_bytes(serializers, [])
+
+    def from_bytes(self, data, **kwargs):
+        deserializers = OrderedDict(
+            (
+                ("cfg", lambda b: self._set_config(srsly.json_loads(b))),
+            )
+        )
+        util.from_bytes(data, deserializers, [])
+        self.tokenizer = try_sudachi_import(self.split_mode)
+        return self
+
+    def to_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = OrderedDict(
+            (
+                ("cfg", lambda p: srsly.write_json(p, self._get_config())),
+            )
+        )
+        return util.to_disk(path, serializers, [])
+
+    def from_disk(self, path, **kwargs):
+        path = util.ensure_path(path)
+        serializers = OrderedDict(
+            (
+                ("cfg", lambda p: self._set_config(srsly.read_json(p))),
+            )
+        )
+        util.from_disk(path, serializers, [])
+        self.tokenizer = try_sudachi_import(self.split_mode)
+
 
 class JapaneseDefaults(Language.Defaults):
     lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
@@ -203,8 +260,8 @@ class JapaneseDefaults(Language.Defaults):
     writing_system = {"direction": "ltr", "has_case": False, "has_letters": False}
 
     @classmethod
-    def create_tokenizer(cls, nlp=None):
-        return JapaneseTokenizer(cls, nlp)
+    def create_tokenizer(cls, nlp=None, config={}):
+        return JapaneseTokenizer(cls, nlp, config)
 
 
 class Japanese(Language):
diff --git a/spacy/tests/lang/ja/test_serialize.py b/spacy/tests/lang/ja/test_serialize.py
new file mode 100644
index 000000000..018e645bb
--- /dev/null
+++ b/spacy/tests/lang/ja/test_serialize.py
@@ -0,0 +1,37 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import pytest
+from spacy.lang.ja import Japanese
+from ...util import make_tempdir
+
+
+def test_ja_tokenizer_serialize(ja_tokenizer):
+    tokenizer_bytes = ja_tokenizer.to_bytes()
+    nlp = Japanese()
+    nlp.tokenizer.from_bytes(tokenizer_bytes)
+    assert tokenizer_bytes == nlp.tokenizer.to_bytes()
+    assert nlp.tokenizer.split_mode == None
+
+    with make_tempdir() as d:
+        file_path = d / "tokenizer"
+        ja_tokenizer.to_disk(file_path)
+        nlp = Japanese()
+        nlp.tokenizer.from_disk(file_path)
+        assert tokenizer_bytes == nlp.tokenizer.to_bytes()
+        assert nlp.tokenizer.split_mode == None
+
+    # split mode is (de)serialized correctly
+    nlp = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
+    nlp_r = Japanese()
+    nlp_bytes = nlp.to_bytes()
+    nlp_r.from_bytes(nlp_bytes)
+    assert nlp_bytes == nlp_r.to_bytes()
+    assert nlp_r.tokenizer.split_mode == "B"
+
+    with make_tempdir() as d:
+        nlp.to_disk(d)
+        nlp_r = Japanese()
+        nlp_r.from_disk(d)
+        assert nlp_bytes == nlp_r.to_bytes()
+        assert nlp_r.tokenizer.split_mode == "B"
diff --git a/spacy/tests/lang/ja/test_tokenizer.py b/spacy/tests/lang/ja/test_tokenizer.py
index 5213aed58..82c43fe4c 100644
--- a/spacy/tests/lang/ja/test_tokenizer.py
+++ b/spacy/tests/lang/ja/test_tokenizer.py
@@ -3,6 +3,8 @@ from __future__ import unicode_literals
 
 import pytest
 
+from ...tokenizer.test_naughty_strings import NAUGHTY_STRINGS
+from spacy.lang.ja import Japanese
 
 # fmt: off
 TOKENIZER_TESTS = [
@@ -55,21 +57,39 @@ def test_ja_tokenizer_pos(ja_tokenizer, text, expected_pos):
     pos = [token.pos_ for token in ja_tokenizer(text)]
     assert pos == expected_pos
 
+
 @pytest.mark.parametrize("text,expected_sents", SENTENCE_TESTS)
 def test_ja_tokenizer_pos(ja_tokenizer, text, expected_sents):
     sents = [str(sent) for sent in ja_tokenizer(text).sents]
     assert sents == expected_sents
 
 
-def test_extra_spaces(ja_tokenizer):
+def test_ja_tokenizer_extra_spaces(ja_tokenizer):
     # note: three spaces after "I"
     tokens = ja_tokenizer("I   like cheese.")
     assert tokens[1].orth_ == "  "
 
-from ...tokenizer.test_naughty_strings import NAUGHTY_STRINGS
 
 
 @pytest.mark.parametrize("text", NAUGHTY_STRINGS)
-def test_tokenizer_naughty_strings(ja_tokenizer, text):
+def test_ja_tokenizer_naughty_strings(ja_tokenizer, text):
    tokens = ja_tokenizer(text)
     assert tokens.text_with_ws == text
+
+@pytest.mark.parametrize("text,len_a,len_b,len_c",
+    [
+        ("選挙管理委員会", 4, 3, 1),
+        ("客室乗務員", 3, 2, 1),
+        ("労働者協同組合", 4, 3, 1),
+        ("機能性食品", 3, 2, 1),
+    ]
+)
+def test_ja_tokenizer_split_modes(ja_tokenizer, text, len_a, len_b, len_c):
+    nlp_a = Japanese(meta={"tokenizer": {"config": {"split_mode": "A"}}})
+    nlp_b = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
+    nlp_c = Japanese(meta={"tokenizer": {"config": {"split_mode": "C"}}})
+
+    assert len(ja_tokenizer(text)) == len_a
+    assert len(nlp_a(text)) == len_a
+    assert len(nlp_b(text)) == len_b
+    assert len(nlp_c(text)) == len_c
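
Usage sketch (not part of the patch): the snippet below mirrors what the tests above exercise, assuming spaCy v2.x with SudachiPy and its dictionary installed; variable names such as `nlp_reloaded` are illustrative only.

    # Select a SudachiPy split mode via the tokenizer config in `meta`,
    # then check that it survives a serialization round trip.
    from spacy.lang.ja import Japanese

    nlp = Japanese(meta={"tokenizer": {"config": {"split_mode": "B"}}})
    doc = nlp("選挙管理委員会")
    print([t.text for t in doc])  # mode "B" yields 3 tokens for this text, per the tests above

    nlp_reloaded = Japanese()
    nlp_reloaded.from_bytes(nlp.to_bytes())  # split_mode is restored from the serialized config
    assert nlp_reloaded.tokenizer.split_mode == "B"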