Add model_to_bytes and model_from_bytes helpers. Probably belong in thinc.

This commit is contained in:
Matthew Honnibal 2017-05-29 09:27:04 +02:00
parent 6dad4117ad
commit 1fa2bfb600
2 changed files with 23 additions and 7 deletions

View File

@@ -16,33 +16,37 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
from .tokens.doc import Doc from .tokens.doc import Doc
import dill
import numpy import numpy
import io import io
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
def model_to_bytes(model):
    """Serialize a thinc model's weights to a bytes object.

    Walks the model and all nested sub-layers breadth-first, collecting for
    each parametrized layer (one with a `_mem` attribute) its weight array,
    its parameter-offset table, and its dimension table, then packs the lot
    with msgpack (msgpack_numpy is patched in at module level, which is
    presumably what lets the numpy weight arrays round-trip — TODO confirm).

    model (thinc model): the root layer to serialize.
    RETURNS (bytes): msgpack-encoded weights, offset metas, and dims.
    """
    weights = []
    metas = []
    dims = []
    queue = [model]
    i = 0
    for layer in queue:
        if hasattr(layer, '_mem'):
            weights.append(layer._mem.weights)
            # _offsets maps parameter names to slices of the weights array;
            # cast to tuple so msgpack can encode it.
            metas.append(tuple(layer._mem._offsets))
            # dims may be absent on some layers; record None as a placeholder
            # so indices stay aligned with `weights` and `metas`.
            dims.append(getattr(layer, '_dims', None))
            i += 1
        if hasattr(layer, '_layers'):
            # Extending the list being iterated implements a breadth-first
            # walk over nested layers.
            queue.extend(layer._layers)
    data = {'metas': tuple(metas), 'weights': tuple(weights),
            'dims': tuple(dims)}
    # NOTE(review): the counterpart model_from_bytes calls msgpack.loads on
    # this payload — deserializing untrusted bytes; only load trusted data.
    return msgpack.dumps(data)
def model_from_bytes(model, bytes_data): def model_from_bytes(model, bytes_data):
# TODO: Replace the pickle here with something else data = msgpack.loads(bytes_data)
data = dill.loads(bytes_data)
metas = data['metas'] metas = data['metas']
weights = data['weights'] weights = data['weights']
dims = data['dims']
queue = [model] queue = [model]
i = 0 i = 0
for layer in queue: for layer in queue:
@@ -52,6 +56,8 @@ def model_from_bytes(model, bytes_data):
flat_params = params.ravel() flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params flat_mem[:flat_params.size] = flat_params
layer._mem._offsets.update(metas[i]) layer._mem._offsets.update(metas[i])
if hasattr(layer, '_dims'):
layer._dims.update(dims[i])
i += 1 i += 1
if hasattr(layer, '_layers'): if hasattr(layer, '_layers'):
queue.extend(layer._layers) queue.extend(layer._layers)

View File

@ -37,3 +37,13 @@ def test_multi_model_roundtrip_bytes():
assert model._layers[1].b[0, 0] == 2 assert model._layers[1].b[0, 0] == 2
def test_multi_model_load_missing_dims():
    """Weights serialized from a fully-specified chain restore correctly
    into a chain constructed without explicit input dimensions."""
    source = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
    source._layers[0].b += 1
    source._layers[1].b += 2
    payload = model_to_bytes(source)
    # Target model omits the input dims; loading should fill them in.
    target = chain(Maxout(5), Maxout())
    model_from_bytes(target, payload)
    assert target._layers[0].b[0, 0] == 1
    assert target._layers[1].b[0, 0] == 2