From 1fa2bfb600c9617eb24b4d9269e08e38f43e0401 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 29 May 2017 09:27:04 +0200 Subject: [PATCH] Add model_to_bytes and model_from_bytes helpers. Probably belong in thinc. --- spacy/_ml.py | 20 +++++++++++++------- spacy/tests/test_misc.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/spacy/_ml.py b/spacy/_ml.py index 3c2f4ccc7..b09e2ef95 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -16,33 +16,37 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP from .tokens.doc import Doc -import dill import numpy import io +import msgpack +import msgpack_numpy +msgpack_numpy.patch() def model_to_bytes(model): weights = [] metas = [] + dims = [] queue = [model] i = 0 for layer in queue: if hasattr(layer, '_mem'): weights.append(layer._mem.weights) - metas.append(layer._mem._offsets) + metas.append(tuple(layer._mem._offsets)) + dims.append(getattr(layer, '_dims', None)) i += 1 if hasattr(layer, '_layers'): queue.extend(layer._layers) - data = {'metas': metas, 'weights': weights} - # TODO: Replace the pickle here with something else - return dill.dumps(data) + data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims': + tuple(dims)} + return msgpack.dumps(data) def model_from_bytes(model, bytes_data): - # TODO: Replace the pickle here with something else - data = dill.loads(bytes_data) + data = msgpack.loads(bytes_data) metas = data['metas'] weights = data['weights'] + dims = data['dims'] queue = [model] i = 0 for layer in queue: @@ -52,6 +56,8 @@ def model_from_bytes(model, bytes_data): flat_params = params.ravel() flat_mem[:flat_params.size] = flat_params layer._mem._offsets.update(metas[i]) + if hasattr(layer, '_dims'): + layer._dims.update(dims[i]) i += 1 if hasattr(layer, '_layers'): queue.extend(layer._layers) diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index 27c8d9f62..404422289 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -37,3 +37,13 @@ def test_multi_model_roundtrip_bytes(): assert model._layers[1].b[0, 0] == 2 +def test_multi_model_load_missing_dims(): + model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3)) + model._layers[0].b += 1 + model._layers[1].b += 2 + data = model_to_bytes(model) + + model2 = chain(Maxout(5), Maxout()) + model_from_bytes(model2, data) + assert model2._layers[0].b[0, 0] == 1 + assert model2._layers[1].b[0, 0] == 2