mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 13:14:11 +03:00
Add model_to_bytes and model_from_bytes helpers. These probably belong in thinc.
This commit is contained in:
parent
6dad4117ad
commit
1fa2bfb600
20
spacy/_ml.py
20
spacy/_ml.py
|
@@ -16,33 +16,37 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
|
|||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from .tokens.doc import Doc
|
||||
|
||||
import dill
|
||||
import numpy
|
||||
import io
|
||||
import msgpack
|
||||
import msgpack_numpy
|
||||
msgpack_numpy.patch()
|
||||
|
||||
|
||||
def model_to_bytes(model):
    """Serialize the parameters of ``model`` and its sub-layers to bytes.

    The layer tree is walked breadth-first, starting at ``model`` and
    following each layer's ``_layers`` attribute. For every layer that has a
    ``_mem`` attribute, three things are recorded at the same position:

    * ``weights`` -- the layer's ``_mem.weights`` buffer,
    * ``metas``   -- ``tuple(layer._mem._offsets)``,
    * ``dims``    -- the layer's ``_dims`` attribute, or ``None`` if absent.

    The three sequences are packed with msgpack (the module calls
    ``msgpack_numpy.patch()`` at import time) rather than pickled with dill.
    ``model_from_bytes`` walks its model with the same queue-based traversal
    and consumes the entries by position, so the order here must not change.

    model: The root layer to serialize.
    RETURNS: The msgpack-encoded payload with keys 'metas', 'weights', 'dims'.
    """
    weights = []
    metas = []
    dims = []
    # Extending the list while iterating over it is deliberate: newly
    # discovered children are visited after the layers already queued,
    # which yields a breadth-first walk without a separate index.
    queue = [model]
    for layer in queue:
        if hasattr(layer, '_mem'):
            weights.append(layer._mem.weights)
            # NOTE(review): if ``_offsets`` is a dict, tuple() keeps only its
            # keys -- confirm this matches what model_from_bytes expects
            # when it calls ``layer._mem._offsets.update(metas[i])``.
            metas.append(tuple(layer._mem._offsets))
            dims.append(getattr(layer, '_dims', None))
        if hasattr(layer, '_layers'):
            queue.extend(layer._layers)
    data = {
        'metas': tuple(metas),
        'weights': tuple(weights),
        'dims': tuple(dims),
    }
    return msgpack.dumps(data)
|
||||
|
||||
|
||||
def model_from_bytes(model, bytes_data):
|
||||
# TODO: Replace the pickle here with something else
|
||||
data = dill.loads(bytes_data)
|
||||
data = msgpack.loads(bytes_data)
|
||||
metas = data['metas']
|
||||
weights = data['weights']
|
||||
dims = data['dims']
|
||||
queue = [model]
|
||||
i = 0
|
||||
for layer in queue:
|
||||
|
@@ -52,6 +56,8 @@ def model_from_bytes(model, bytes_data):
|
|||
flat_params = params.ravel()
|
||||
flat_mem[:flat_params.size] = flat_params
|
||||
layer._mem._offsets.update(metas[i])
|
||||
if hasattr(layer, '_dims'):
|
||||
layer._dims.update(dims[i])
|
||||
i += 1
|
||||
if hasattr(layer, '_layers'):
|
||||
queue.extend(layer._layers)
|
||||
|
|
|
@@ -37,3 +37,13 @@ def test_multi_model_roundtrip_bytes():
|
|||
assert model._layers[1].b[0, 0] == 2
|
||||
|
||||
|
||||
def test_multi_model_load_missing_dims():
    """Round-trip a two-layer model into one built with fewer dimensions.

    The source model is constructed with explicit sizes; the target model
    omits some of them, so loading has to work from the serialized data.
    The biases are shifted by distinct amounts per layer so that a
    successful load can be recognized from the first bias entry.
    """
    source = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
    source_layers = source._layers
    source_layers[0].b += 1
    source_layers[1].b += 2
    serialized = model_to_bytes(source)

    # Target layers are declared with missing dimensions on purpose.
    target = chain(Maxout(5), Maxout())
    model_from_bytes(target, serialized)

    target_layers = target._layers
    assert target_layers[0].b[0, 0] == 1
    assert target_layers[1].b[0, 0] == 2
|
||||
|
|
Loading…
Reference in New Issue
Block a user