diff --git a/spacy/_ml.py b/spacy/_ml.py index b09e2ef95..132bd55a2 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -18,51 +18,8 @@ from .tokens.doc import Doc import numpy import io -import msgpack -import msgpack_numpy -msgpack_numpy.patch() -def model_to_bytes(model): - weights = [] - metas = [] - dims = [] - queue = [model] - i = 0 - for layer in queue: - if hasattr(layer, '_mem'): - weights.append(layer._mem.weights) - metas.append(tuple(layer._mem._offsets)) - dims.append(getattr(layer, '_dims', None)) - i += 1 - if hasattr(layer, '_layers'): - queue.extend(layer._layers) - data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims': - tuple(dims)} - return msgpack.dumps(data) - - -def model_from_bytes(model, bytes_data): - data = msgpack.loads(bytes_data) - metas = data['metas'] - weights = data['weights'] - dims = data['dims'] - queue = [model] - i = 0 - for layer in queue: - if hasattr(layer, '_mem'): - params = weights[i] - flat_mem = layer._mem._mem.ravel() - flat_params = params.ravel() - flat_mem[:flat_params.size] = flat_params - layer._mem._offsets.update(metas[i]) - if hasattr(layer, '_dims'): - layer._dims.update(dims[i]) - i += 1 - if hasattr(layer, '_layers'): - queue.extend(layer._layers) - - def _init_for_precomputed(W, ops): if (W**2).sum() != 0.: return diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index 404422289..2c0ff0520 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from ..util import ensure_path -from .._ml import model_to_bytes, model_from_bytes +from ..util import model_to_bytes, model_from_bytes from pathlib import Path import pytest diff --git a/spacy/util.py b/spacy/util.py index 5766d2db1..72dede705 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -11,6 +11,10 @@ import sys import textwrap import random +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + from .symbols import ORTH from .compat import cupy, CudaStream, 
path2str, basestring_, input_, unicode_
@@ -408,18 +412,62 @@ def get_raw_input(description, default=False):
     return user_input
 
 
-def to_bytes(unserialized, exclude):
+def to_bytes(getters, exclude):
     serialized = {}
-    for key, value in unserialized.items():
-        if key in exclude:
-            continue
-        elif hasattr(value, 'to_bytes'):
-            serialized[key] = value.to_bytes()
-        else:
-            serialized[key] = ujson.dumps(value)
-    return ujson.dumps(serialized)
+    for key, getter in getters.items():
+        if key not in exclude:
+            serialized[key] = getter()
+    return msgpack.dumps(serialized)
 
 
+def from_bytes(bytes_data, setters, exclude):
+    msg = msgpack.loads(bytes_data)
+    for key, setter in setters.items():
+        if key not in exclude:
+            setter(msg[key])
+    return msg
+
+
+def model_to_bytes(model):
+    weights = []
+    metas = []
+    dims = []
+    queue = [model]
+    i = 0
+    for layer in queue:
+        if hasattr(layer, '_mem'):
+            weights.append(layer._mem.weights)
+            metas.append(tuple(layer._mem._offsets))
+            dims.append(getattr(layer, '_dims', None))
+            i += 1
+        if hasattr(layer, '_layers'):
+            queue.extend(layer._layers)
+    data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
+            tuple(dims)}
+    return msgpack.dumps(data)
+
+
+def model_from_bytes(model, bytes_data):
+    data = msgpack.loads(bytes_data)
+    metas = data['metas']
+    weights = data['weights']
+    dims = data['dims']
+    queue = [model]
+    i = 0
+    for layer in queue:
+        if hasattr(layer, '_mem'):
+            params = weights[i]
+            flat_mem = layer._mem._mem.ravel()
+            flat_params = params.ravel()
+            flat_mem[:flat_params.size] = flat_params
+            layer._mem._offsets.update(metas[i])
+            if hasattr(layer, '_dims'):
+                layer._dims.update(dims[i])
+            i += 1
+        if hasattr(layer, '_layers'):
+            queue.extend(layer._layers)
+
+
 def print_table(data, title=None):
     """Print data in table format.