Fix for serialization

This commit is contained in:
Matthew Honnibal 2017-05-29 13:47:42 +02:00
parent 2e364f7ecd
commit a1960c2d09

View File

@ -437,7 +437,10 @@ def model_to_bytes(model):
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
weights.append(layer._mem.weights)
if layer._mem.weights.size:
weights.append(layer._mem.weights)
else:
weights.append(None)
metas.append(tuple(layer._mem._offsets))
dims.append(getattr(layer, '_dims', None))
i += 1
@ -458,9 +461,10 @@ def model_from_bytes(model, bytes_data):
for layer in queue:
if hasattr(layer, '_mem'):
params = weights[i]
flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
if params is not None:
flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
layer._mem._offsets.update(metas[i])
if hasattr(layer, '_dims'):
layer._dims.update(dims[i])