Fix encoding on tagger serialization

This commit is contained in:
ines 2017-06-02 17:29:21 +02:00
parent 69d58dbc89
commit 1b593bbd6d

View File

@ -303,7 +303,9 @@ class NeuralTagger(object):
serialize = OrderedDict(( serialize = OrderedDict((
('model', lambda: self.model.to_bytes()), ('model', lambda: self.model.to_bytes()),
('vocab', lambda: self.vocab.to_bytes()), ('vocab', lambda: self.vocab.to_bytes()),
('tag_map', lambda: msgpack.dumps(self.vocab.morphology.tag_map)) ('tag_map', lambda: msgpack.dumps(self.vocab.morphology.tag_map,
use_bin_type=True,
encoding='utf8'))
)) ))
return util.to_bytes(serialize, exclude) return util.to_bytes(serialize, exclude)
@ -315,7 +317,7 @@ class NeuralTagger(object):
self.model.from_bytes(b) self.model.from_bytes(b)
def load_tag_map(b): def load_tag_map(b):
tag_map = msgpack.loads(b) tag_map = msgpack.loads(b, encoding='utf8')
self.vocab.morphology = Morphology( self.vocab.morphology = Morphology(
self.vocab.strings, tag_map=tag_map, self.vocab.strings, tag_map=tag_map,
lemmatizer=self.vocab.morphology.lemmatizer) lemmatizer=self.vocab.morphology.lemmatizer)
@ -330,8 +332,10 @@ class NeuralTagger(object):
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
serialize = OrderedDict(( serialize = OrderedDict((
('vocab', lambda p: self.vocab.to_disk(p)), ('vocab', lambda p: self.vocab.to_disk(p)),
('tag_map', lambda p: p.open('w').write(msgpack.dumps( ('tag_map', lambda p: p.open('wb').write(msgpack.dumps(
self.vocab.morphology.tag_map))), self.vocab.morphology.tag_map,
use_bin_type=True,
encoding='utf8'))),
('model', lambda p: p.open('wb').write(self.model.to_bytes())), ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
)) ))
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
@ -344,8 +348,8 @@ class NeuralTagger(object):
self.model.from_bytes(p.open('rb').read()) self.model.from_bytes(p.open('rb').read())
def load_tag_map(p): def load_tag_map(p):
with p.open() as file_: with p.open('rb') as file_:
tag_map = msgpack.loads(file_.read()) tag_map = msgpack.loads(file_.read(), encoding='utf8')
self.vocab.morphology = Morphology( self.vocab.morphology = Morphology(
self.vocab.strings, tag_map=tag_map, self.vocab.strings, tag_map=tag_map,
lemmatizer=self.vocab.morphology.lemmatizer) lemmatizer=self.vocab.morphology.lemmatizer)