mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Patch serialization bug raised in #1105
This commit is contained in:
parent
f0f2739ae3
commit
8978212ee5
|
@ -157,11 +157,13 @@ class BaseThincComponent(object):
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
"""Serialize the pipe to a bytestring."""
|
"""Serialize the pipe to a bytestring."""
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict()
|
||||||
('cfg', lambda: json_dumps(self.cfg)),
|
serialize['cfg'] = lambda: json_dumps(self.cfg)
|
||||||
('model', lambda: self.model.to_bytes()),
|
if self.model in (True, False, None):
|
||||||
('vocab', lambda: self.vocab.to_bytes())
|
serialize['model'] = lambda: self.model
|
||||||
))
|
else:
|
||||||
|
serialize['model'] = self.model.to_bytes
|
||||||
|
serialize['vocab'] = self.vocab.to_bytes
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, **exclude):
|
def from_bytes(self, bytes_data, **exclude):
|
||||||
|
@ -182,11 +184,11 @@ class BaseThincComponent(object):
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
"""Serialize the pipe to disk."""
|
"""Serialize the pipe to disk."""
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict()
|
||||||
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
|
serialize['cfg'] = lambda p: p.open('w').write(json_dumps(self.cfg))
|
||||||
('vocab', lambda p: self.vocab.to_disk(p)),
|
serialize['vocab'] = lambda p: self.vocab.to_disk(p)
|
||||||
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
if self.model not in (None, True, False):
|
||||||
))
|
serialize['model'] = lambda p: p.open('wb').write(self.model.to_bytes())
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
def from_disk(self, path, **exclude):
|
def from_disk(self, path, **exclude):
|
||||||
|
@ -437,13 +439,16 @@ class NeuralTagger(BaseThincComponent):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict()
|
||||||
('model', lambda: self.model.to_bytes()),
|
if self.model in (None, True, False):
|
||||||
('vocab', lambda: self.vocab.to_bytes()),
|
serialize['model'] = lambda: self.model
|
||||||
('tag_map', lambda: msgpack.dumps(self.vocab.morphology.tag_map,
|
else:
|
||||||
use_bin_type=True,
|
serialize['model'] = self.model.to_bytes
|
||||||
encoding='utf8'))
|
serialize['vocab'] = self.vocab.to_bytes
|
||||||
))
|
|
||||||
|
serialize['tag_map'] = lambda: msgpack.dumps(self.vocab.morphology.tag_map,
|
||||||
|
use_bin_type=True,
|
||||||
|
encoding='utf8')
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, **exclude):
|
def from_bytes(self, bytes_data, **exclude):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user