Move serialization functions to util

This commit is contained in:
Matthew Honnibal 2017-05-29 10:13:42 +02:00
parent 1fa2bfb600
commit c91b121aeb
3 changed files with 58 additions and 53 deletions

View File

@ -18,51 +18,8 @@ from .tokens.doc import Doc
import numpy
import io
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
def model_to_bytes(model):
weights = []
metas = []
dims = []
queue = [model]
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
weights.append(layer._mem.weights)
metas.append(tuple(layer._mem._offsets))
dims.append(getattr(layer, '_dims', None))
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)
data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
tuple(dims)}
return msgpack.dumps(data)
def model_from_bytes(model, bytes_data):
data = msgpack.loads(bytes_data)
metas = data['metas']
weights = data['weights']
dims = data['dims']
queue = [model]
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
params = weights[i]
flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
layer._mem._offsets.update(metas[i])
if hasattr(layer, '_dims'):
layer._dims.update(dims[i])
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)
def _init_for_precomputed(W, ops):
if (W**2).sum() != 0.:
return

View File

@ -2,7 +2,7 @@
from __future__ import unicode_literals
from ..util import ensure_path
from .._ml import model_to_bytes, model_from_bytes
from ..util import model_to_bytes, model_from_bytes
from pathlib import Path
import pytest

View File

@ -11,6 +11,10 @@ import sys
import textwrap
import random
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
@ -408,18 +412,62 @@ def get_raw_input(description, default=False):
return user_input
def to_bytes(unserialized, exclude):
def to_bytes(getters, exclude):
serialized = {}
for key, value in unserialized.items():
if key in exclude:
continue
elif hasattr(value, 'to_bytes'):
serialized[key] = value.to_bytes()
else:
serialized[key] = ujson.dumps(value)
return ujson.dumps(serialized)
for key, getter in getters.items():
if key not in exclude:
serialized[key] = getter()
return messagepack.dumps(serialized)
def from_bytes(bytes_data, setters, exclude):
msg = messagepack.loads(bytes_data)
for key, setter in setters.items():
if key not in exclude:
setter(msg[key])
return msg
def model_to_bytes(model):
weights = []
metas = []
dims = []
queue = [model]
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
weights.append(layer._mem.weights)
metas.append(tuple(layer._mem._offsets))
dims.append(getattr(layer, '_dims', None))
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)
data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
tuple(dims)}
return msgpack.dumps(data)
def model_from_bytes(model, bytes_data):
data = msgpack.loads(bytes_data)
metas = data['metas']
weights = data['weights']
dims = data['dims']
queue = [model]
i = 0
for layer in queue:
if hasattr(layer, '_mem'):
params = weights[i]
flat_mem = layer._mem._mem.ravel()
flat_params = params.ravel()
flat_mem[:flat_params.size] = flat_params
layer._mem._offsets.update(metas[i])
if hasattr(layer, '_dims'):
layer._dims.update(dims[i])
i += 1
if hasattr(layer, '_layers'):
queue.extend(layer._layers)
def print_table(data, title=None):
"""Print data in table format.