diff --git a/spacy/language.py b/spacy/language.py index bebdeab20..9f8cc49e1 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -5,7 +5,7 @@ import pathlib from contextlib import contextmanager import shutil -import ujson as json +import ujson try: @@ -13,6 +13,10 @@ try: except NameError: basestring = str +try: + unicode +except NameError: + unicode = str from .tokenizer import Tokenizer from .vocab import Vocab @@ -226,12 +230,21 @@ class Language(object): parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) - with (dep_model_dir / 'config.json').open('w') as file_: - json.dump(parser_cfg, file_) - with (ner_model_dir / 'config.json').open('w') as file_: - json.dump(entity_cfg, file_) - with (pos_model_dir / 'config.json').open('w') as file_: - json.dump(tagger_cfg, file_) + with (dep_model_dir / 'config.json').open('wb') as file_: + data = ujson.dumps(parser_cfg) + if isinstance(data, unicode): + data = data.encode('utf8') + file_.write(data) + with (ner_model_dir / 'config.json').open('wb') as file_: + data = ujson.dumps(entity_cfg) + if isinstance(data, unicode): + data = data.encode('utf8') + file_.write(data) + with (pos_model_dir / 'config.json').open('wb') as file_: + data = ujson.dumps(tagger_cfg) + if isinstance(data, unicode): + data = data.encode('utf8') + file_.write(data) self = cls( path=path, @@ -391,12 +404,14 @@ class Language(object): else: entity_iob_freqs = [] entity_type_freqs = [] - with (path / 'vocab' / 'serializer.json').open('w') as file_: - file_.write( - json.dumps([ - (TAG, tagger_freqs), - (DEP, dep_freqs), - (ENT_IOB, entity_iob_freqs), - (ENT_TYPE, entity_type_freqs), - (HEAD, head_freqs) - ])) + with (path / 'vocab' / 'serializer.json').open('wb') as file_: + data = ujson.dumps([ + (TAG, tagger_freqs), + (DEP, dep_freqs), + (ENT_IOB, entity_iob_freqs), + (ENT_TYPE, entity_type_freqs), + (HEAD, head_freqs) + ]) + if isinstance(data, unicode): + data = data.encode('utf8') + file_.write(data)