Add model_to_bytes and model_from_bytes helpers. Probably belong in thinc.

This commit is contained in:
Matthew Honnibal 2017-05-29 09:27:04 +02:00
parent 6dad4117ad
commit 1fa2bfb600
2 changed files with 23 additions and 7 deletions

View File

@ -16,33 +16,37 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
from .tokens.doc import Doc
import dill
import numpy
import io
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
def model_to_bytes(model):
weights = []
metas = []
dims = []
queue = [model]
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
weights.append(layer._mem.weights)
metas.append(layer._mem._offsets)
metas.append(tuple(layer._mem._offsets))
dims.append(getattr(layer, '_dims', None))
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)
data = {'metas': metas, 'weights': weights}
# TODO: Replace the pickle here with something else
return dill.dumps(data)
data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
tuple(dims)}
return msgpack.dumps(data)
def model_from_bytes(model, bytes_data):
# TODO: Replace the pickle here with something else
data = dill.loads(bytes_data)
data = msgpack.loads(bytes_data)
metas = data['metas']
weights = data['weights']
dims = data['dims']
queue = [model]
i = 0
for layer in queue:
@ -52,6 +56,8 @@ def model_from_bytes(model, bytes_data):
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
layer._mem._offsets.update(metas[i])
if hasattr(layer, '_dims'):
layer._dims.update(dims[i])
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)

View File

@ -37,3 +37,13 @@ def test_multi_model_roundtrip_bytes():
assert model._layers[1].b[0, 0] == 2
def test_multi_model_load_missing_dims():
model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
model._layers[0].b += 1
model._layers[1].b += 2
data = model_to_bytes(model)
model2 = chain(Maxout(5), Maxout())
model_from_bytes(model2, data)
assert model2._layers[0].b[0, 0] == 1
assert model2._layers[1].b[0, 0] == 2