Fix serialisation, for reals this time

This commit is contained in:
Matthew Honnibal 2017-05-29 17:52:08 -05:00
parent 35d981241f
commit 2a061e2777

View File

@ -12,6 +12,7 @@ import textwrap
import random import random
import numpy import numpy
import io import io
import dill
import msgpack import msgpack
import msgpack_numpy import msgpack_numpy
@ -422,45 +423,53 @@ def from_bytes(bytes_data, setters, exclude):
return msg return msg
# This stuff really belongs in thinc -- but I expect
# to refactor how all this works in thinc anyway.
# What a mess!
def model_to_bytes(model):
    """Serialize the parameters of `model` and its sub-layers to msgpack bytes.

    Layers are visited breadth-first, starting at `model` and following each
    layer's `_layers` attribute. Every layer that has a `_mem` attribute
    contributes exactly one entry to the output list, in visiting order:

        {'dims': <copy of layer._dims, or {}>,
         'seed': <layer.seed, only when the layer has one>,
         'params': [{'name', 'offset', 'shape', 'value'}, ...]}

    model: The root layer to serialize.
    RETURNS (bytes): The msgpack encoding of `{'weights': [entry, ...]}`.
    """
    weights = []
    # `queue` is extended while it is being iterated; appending to a list
    # during a `for` loop is well-defined and yields a breadth-first walk.
    queue = [model]
    for layer in queue:
        if hasattr(layer, '_mem'):
            entry = {'dims': dict(getattr(layer, '_dims', {})), 'params': []}
            if hasattr(layer, 'seed'):
                entry['seed'] = layer.seed
            for (id_, name), (start, row, shape) in layer._mem._offsets.items():
                # NOTE(review): row 1 presumably holds the gradient rather
                # than the parameter itself -- confirm against thinc's Memory.
                if row == 1:
                    continue
                param = layer._mem.get((id_, name))
                if not isinstance(layer._mem.weights, numpy.ndarray):
                    # Presumably a device (e.g. cupy) array, whose `.get()`
                    # copies it to a host numpy array -- TODO confirm.
                    param = param.get()
                entry['params'].append(
                    {
                        'name': name,
                        'offset': start,
                        'shape': shape,
                        'value': param,
                    }
                )
            weights.append(entry)
        if hasattr(layer, '_layers'):
            queue.extend(layer._layers)
    return msgpack.dumps({'weights': weights})
def _msgpack_text(value):
    """Decode `value` from UTF-8 if it is a bytes object distinct from `str`."""
    # On Python 2 `bytes is str`, so nothing is decoded there.
    if isinstance(value, bytes) and not isinstance(value, str):
        return value.decode('utf8')
    return value


def _msgpack_text_keys(mapping):
    """Return a shallow copy of `mapping` whose bytes keys are decoded."""
    return {_msgpack_text(key): value for key, value in mapping.items()}


def model_from_bytes(model, bytes_data):
    """Load parameters written by `model_to_bytes` into `model`, in place.

    The layers of `model` are visited in the same breadth-first order used
    when serializing (following `_layers`); the n-th layer that has a `_mem`
    attribute receives the n-th entry of the payload's 'weights' list. For
    each such layer the stored seed (if any) is assigned to `layer.seed`,
    every stored dimension is set as an attribute on the layer, and every
    stored parameter value is copied into the layer attribute of that name.

    model: The root layer to update. It must expose the same layer structure
        as the model that was serialized.
    bytes_data (bytes): Output of `model_to_bytes`.
    RAISES (IndexError): If the payload has fewer entries than the model has
        layers with a `_mem` attribute.
    """
    # Depending on the msgpack version and unpacker options, map keys and
    # strings can come back as bytes under Python 3; normalise them so the
    # text-key lookups and `setattr`/`getattr` calls below keep working.
    data = _msgpack_text_keys(msgpack.loads(bytes_data))
    weights = data['weights']
    queue = [model]
    i = 0
    for layer in queue:
        if hasattr(layer, '_mem'):
            if i >= len(weights):
                raise IndexError(
                    "Serialized model has only %d weight entries, but the "
                    "target model has more layers with parameters" % len(weights)
                )
            entry = _msgpack_text_keys(weights[i])
            if 'seed' in entry:
                layer.seed = entry['seed']
            for dim, value in entry['dims'].items():
                setattr(layer, _msgpack_text(dim), value)
            for param in entry['params']:
                param = _msgpack_text_keys(param)
                dest = getattr(layer, _msgpack_text(param['name']))
                # Copy into the existing array rather than rebinding the
                # attribute.
                dest[:] = param['value']
            i += 1
        if hasattr(layer, '_layers'):
            queue.extend(layer._layers)