Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 18:26:30 +03:00)

Fix serialisation, for reals this time

This commit is contained in:
parent 35d981241f
commit 2a061e2777
@@ -12,6 +12,7 @@ import textwrap
 import random
 import numpy
 import io
+import dill

 import msgpack
 import msgpack_numpy
@@ -422,45 +423,53 @@ def from_bytes(bytes_data, setters, exclude):
     return msg


+# This stuff really belongs in thinc -- but I expect
+# to refactor how all this works in thinc anyway.
+# What a mess!
 def model_to_bytes(model):
     weights = []
-    metas = []
-    dims = []
     queue = [model]
     i = 0
     for layer in queue:
         if hasattr(layer, '_mem'):
-            if isinstance(layer._mem.weights, numpy.ndarray):
-                weights.append(layer._mem.weights)
-            else:
-                weights.append(layer._mem.weights.get())
-            metas.append(layer._mem._offsets)
-            dims.append(getattr(layer, '_dims', None))
+            weights.append({'dims': dict(getattr(layer, '_dims', {})), 'params': []})
+            if hasattr(layer, 'seed'):
+                weights[-1]['seed'] = layer.seed
+            for (id_, name), (start, row, shape) in layer._mem._offsets.items():
+                if row == 1:
+                    continue
+                param = layer._mem.get((id_, name))
+                if not isinstance(layer._mem.weights, numpy.ndarray):
+                    param = param.get()
+                weights[-1]['params'].append(
+                    {
+                        'name': name,
+                        'offset': start,
+                        'shape': shape,
+                        'value': param,
+                    }
+                )
             i += 1
         if hasattr(layer, '_layers'):
             queue.extend(layer._layers)
-    data = {'metas': ujson.dumps(metas), 'weights': weights, 'dims': ujson.dumps(dims)}
-    return msgpack.dumps(data)
+    return msgpack.dumps({'weights': weights})


 def model_from_bytes(model, bytes_data):
     data = msgpack.loads(bytes_data)
     weights = data['weights']
-    metas = ujson.loads(data['metas'])
-    dims = ujson.loads(data['dims'])
     queue = [model]
     i = 0
     for layer in queue:
         if hasattr(layer, '_mem'):
-            params = weights[i]
-            layer._mem._get_blob(params.size)
-            layer._mem._i -= params.size
-            flat_mem = layer._mem._mem.ravel()
-            flat_params = params.ravel()
-            flat_mem[:flat_params.size] = flat_params
-            layer._mem._offsets.update(metas[i])
-            if hasattr(layer, '_dims'):
-                layer._dims.update(dims[i])
+            if 'seed' in weights[i]:
+                layer.seed = weights[i]['seed']
+            for dim, value in weights[i]['dims'].items():
+                setattr(layer, dim, value)
+            for param in weights[i]['params']:
+                dest = getattr(layer, param['name'])
+                dest[:] = param['value']
             i += 1
         if hasattr(layer, '_layers'):
             queue.extend(layer._layers)
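For reference, a minimal stand-alone sketch (not part of the commit) of the payload the new model_to_bytes produces: one dict per layer carrying 'dims', an optional 'seed', and a list of named 'params'. The dimension names, seed and parameter values below are hypothetical; msgpack_numpy (imported above) provides patch(), which lets msgpack round-trip the numpy arrays.

    import numpy
    import msgpack
    import msgpack_numpy
    msgpack_numpy.patch()  # teach msgpack to pack/unpack numpy arrays

    # Hypothetical per-layer entry, mirroring what model_to_bytes builds
    weights = [{
        'dims': {'nO': 3, 'nI': 2},
        'seed': 0,
        'params': [{
            'name': 'W',
            'offset': 0,
            'shape': (3, 2),
            'value': numpy.zeros((3, 2), dtype='float32'),
        }],
    }]

    data = msgpack.dumps({'weights': weights})  # same shape of blob model_to_bytes returns
    restored = msgpack.loads(data, raw=False)   # raw=False (msgpack >= 0.5.2) keeps dict keys as str
    assert (restored['weights'][0]['params'][0]['value'] == 0).all()

On the loading side, model_from_bytes walks the model's layers in the same order, restores each layer's dims and seed, then copies each param['value'] into the attribute named by param['name'] in place (dest[:] = param['value']).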