Patch serialization bug raised in #1105

This commit is contained in:
Matthew Honnibal 2017-10-10 03:58:12 +02:00
parent f0f2739ae3
commit 8978212ee5

View File

@ -157,11 +157,13 @@ class BaseThincComponent(object):
def to_bytes(self, **exclude):
"""Serialize the pipe to a bytestring."""
serialize = OrderedDict((
('cfg', lambda: json_dumps(self.cfg)),
('model', lambda: self.model.to_bytes()),
('vocab', lambda: self.vocab.to_bytes())
))
serialize = OrderedDict()
serialize['cfg'] = lambda: json_dumps(self.cfg)
if self.model in (True, False, None):
serialize['model'] = lambda: self.model
else:
serialize['model'] = self.model.to_bytes
serialize['vocab'] = self.vocab.to_bytes
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude):
@ -182,11 +184,11 @@ class BaseThincComponent(object):
def to_disk(self, path, **exclude):
"""Serialize the pipe to disk."""
serialize = OrderedDict((
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
('vocab', lambda p: self.vocab.to_disk(p)),
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
))
serialize = OrderedDict()
serialize['cfg'] = lambda p: p.open('w').write(json_dumps(self.cfg))
serialize['vocab'] = lambda p: self.vocab.to_disk(p)
if self.model not in (None, True, False):
serialize['model'] = lambda p: p.open('wb').write(self.model.to_bytes())
util.to_disk(path, serialize, exclude)
def from_disk(self, path, **exclude):
@ -437,13 +439,16 @@ class NeuralTagger(BaseThincComponent):
yield
def to_bytes(self, **exclude):
serialize = OrderedDict((
('model', lambda: self.model.to_bytes()),
('vocab', lambda: self.vocab.to_bytes()),
('tag_map', lambda: msgpack.dumps(self.vocab.morphology.tag_map,
use_bin_type=True,
encoding='utf8'))
))
serialize = OrderedDict()
if self.model in (None, True, False):
serialize['model'] = lambda: self.model
else:
serialize['model'] = self.model.to_bytes
serialize['vocab'] = self.vocab.to_bytes
serialize['tag_map'] = lambda: msgpack.dumps(self.vocab.morphology.tag_map,
use_bin_type=True,
encoding='utf8')
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude):