Add model_to_bytes and model_from_bytes helpers. Probably belong in thinc.
This commit is contained in:
parent 6dad4117ad
commit 1fa2bfb600
spacy/_ml.py | 20
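The helpers serialize a thinc model tree to a msgpack byte string and back: model_to_bytes walks the model and any nested _layers, collecting each layer's weight memory, parameter offsets, and _dims; model_from_bytes copies the weights back into a freshly built model and restores the offsets and dims, so the receiving model may leave dimensions unspecified and have them filled in from the data (exercised by test_multi_model_load_missing_dims below). A minimal round-trip sketch; the chain/Maxout import paths are assumptions based on the thinc 6.x layout, not taken from this commit:

    # Round-trip sketch. Import paths for chain/Maxout are assumed from the
    # thinc 6.x layout (thinc.api, thinc.neural._classes.*); adjust as needed.
    from thinc.api import chain
    from thinc.neural._classes.maxout import Maxout

    from spacy._ml import model_to_bytes, model_from_bytes

    model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
    data = model_to_bytes(model)  # msgpack bytes: weights + offsets + dims

    # The receiving model can omit dims; model_from_bytes restores them.
    model2 = chain(Maxout(5), Maxout())
    model_from_bytes(model2, data)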
@@ -16,33 +16,37 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
 from .tokens.doc import Doc
 
-import dill
 import numpy
 import io
+import msgpack
+import msgpack_numpy
+msgpack_numpy.patch()
 
 
 def model_to_bytes(model):
     weights = []
     metas = []
+    dims = []
     queue = [model]
     i = 0
     for layer in queue:
         if hasattr(layer, '_mem'):
             weights.append(layer._mem.weights)
-            metas.append(layer._mem._offsets)
+            metas.append(tuple(layer._mem._offsets))
+            dims.append(getattr(layer, '_dims', None))
             i += 1
         if hasattr(layer, '_layers'):
             queue.extend(layer._layers)
-    data = {'metas': metas, 'weights': weights}
-    # TODO: Replace the pickle here with something else
-    return dill.dumps(data)
+    data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
+            tuple(dims)}
+    return msgpack.dumps(data)
 
 
 def model_from_bytes(model, bytes_data):
-    # TODO: Replace the pickle here with something else
-    data = dill.loads(bytes_data)
+    data = msgpack.loads(bytes_data)
     metas = data['metas']
     weights = data['weights']
+    dims = data['dims']
     queue = [model]
     i = 0
     for layer in queue:
@@ -52,6 +56,8 @@ def model_from_bytes(model, bytes_data):
             flat_params = params.ravel()
             flat_mem[:flat_params.size] = flat_params
             layer._mem._offsets.update(metas[i])
+            if hasattr(layer, '_dims'):
+                layer._dims.update(dims[i])
             i += 1
         if hasattr(layer, '_layers'):
             queue.extend(layer._layers)

@@ -37,3 +37,13 @@ def test_multi_model_roundtrip_bytes():
     assert model._layers[1].b[0, 0] == 2
 
 
+def test_multi_model_load_missing_dims():
+    model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
+    model._layers[0].b += 1
+    model._layers[1].b += 2
+    data = model_to_bytes(model)
+
+    model2 = chain(Maxout(5), Maxout())
+    model_from_bytes(model2, data)
+    assert model2._layers[0].b[0, 0] == 1
+    assert model2._layers[1].b[0, 0] == 2
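Switching from dill to msgpack resolves the pickle TODO, but msgpack alone cannot encode numpy arrays; the msgpack_numpy.patch() call at import time hooks msgpack's packer and unpacker so arrays round-trip transparently. A standalone sketch of that mechanism (not spaCy code; assumes msgpack >= 1.0, where map keys round-trip as str):

    import numpy
    import msgpack
    import msgpack_numpy

    msgpack_numpy.patch()  # make msgpack aware of numpy arrays

    weights = numpy.ones((2, 3), dtype='f')
    packed = msgpack.dumps({'weights': (weights,)})  # tuples come back as lists
    unpacked = msgpack.loads(packed)
    assert (unpacked['weights'][0] == weights).all()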