mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-04 03:43:09 +03:00
Move serialization functions to util
This commit is contained in:
parent
1fa2bfb600
commit
c91b121aeb
43
spacy/_ml.py
43
spacy/_ml.py
|
@ -18,49 +18,6 @@ from .tokens.doc import Doc
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import io
|
import io
|
||||||
import msgpack
|
|
||||||
import msgpack_numpy
|
|
||||||
msgpack_numpy.patch()
|
|
||||||
|
|
||||||
|
|
||||||
def model_to_bytes(model):
|
|
||||||
weights = []
|
|
||||||
metas = []
|
|
||||||
dims = []
|
|
||||||
queue = [model]
|
|
||||||
i = 0
|
|
||||||
for layer in queue:
|
|
||||||
if hasattr(layer, '_mem'):
|
|
||||||
weights.append(layer._mem.weights)
|
|
||||||
metas.append(tuple(layer._mem._offsets))
|
|
||||||
dims.append(getattr(layer, '_dims', None))
|
|
||||||
i += 1
|
|
||||||
if hasattr(layer, '_layers'):
|
|
||||||
queue.extend(layer._layers)
|
|
||||||
data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
|
|
||||||
tuple(dims)}
|
|
||||||
return msgpack.dumps(data)
|
|
||||||
|
|
||||||
|
|
||||||
def model_from_bytes(model, bytes_data):
|
|
||||||
data = msgpack.loads(bytes_data)
|
|
||||||
metas = data['metas']
|
|
||||||
weights = data['weights']
|
|
||||||
dims = data['dims']
|
|
||||||
queue = [model]
|
|
||||||
i = 0
|
|
||||||
for layer in queue:
|
|
||||||
if hasattr(layer, '_mem'):
|
|
||||||
params = weights[i]
|
|
||||||
flat_mem = layer._mem._mem.ravel()
|
|
||||||
flat_params = params.ravel()
|
|
||||||
flat_mem[:flat_params.size] = flat_params
|
|
||||||
layer._mem._offsets.update(metas[i])
|
|
||||||
if hasattr(layer, '_dims'):
|
|
||||||
layer._dims.update(dims[i])
|
|
||||||
i += 1
|
|
||||||
if hasattr(layer, '_layers'):
|
|
||||||
queue.extend(layer._layers)
|
|
||||||
|
|
||||||
|
|
||||||
def _init_for_precomputed(W, ops):
|
def _init_for_precomputed(W, ops):
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from ..util import ensure_path
|
from ..util import ensure_path
|
||||||
from .._ml import model_to_bytes, model_from_bytes
|
from ..util import model_to_bytes, model_from_bytes
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pytest
|
import pytest
|
||||||
|
|
|
@ -11,6 +11,10 @@ import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import msgpack_numpy
|
||||||
|
msgpack_numpy.patch()
|
||||||
|
|
||||||
from .symbols import ORTH
|
from .symbols import ORTH
|
||||||
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
|
||||||
|
|
||||||
|
@ -408,16 +412,60 @@ def get_raw_input(description, default=False):
|
||||||
return user_input
|
return user_input
|
||||||
|
|
||||||
|
|
||||||
def to_bytes(unserialized, exclude):
|
def to_bytes(getters, exclude):
|
||||||
serialized = {}
|
serialized = {}
|
||||||
for key, value in unserialized.items():
|
for key, getter in getters.items():
|
||||||
if key in exclude:
|
if key not in exclude:
|
||||||
continue
|
serialized[key] = getter()
|
||||||
elif hasattr(value, 'to_bytes'):
|
return messagepack.dumps(serialized)
|
||||||
serialized[key] = value.to_bytes()
|
|
||||||
else:
|
|
||||||
serialized[key] = ujson.dumps(value)
|
def from_bytes(bytes_data, setters, exclude):
|
||||||
return ujson.dumps(serialized)
|
msg = messagepack.loads(bytes_data)
|
||||||
|
for key, setter in setters.items():
|
||||||
|
if key not in exclude:
|
||||||
|
setter(msg[key])
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def model_to_bytes(model):
|
||||||
|
weights = []
|
||||||
|
metas = []
|
||||||
|
dims = []
|
||||||
|
queue = [model]
|
||||||
|
i = 0
|
||||||
|
for layer in queue:
|
||||||
|
if hasattr(layer, '_mem'):
|
||||||
|
weights.append(layer._mem.weights)
|
||||||
|
metas.append(tuple(layer._mem._offsets))
|
||||||
|
dims.append(getattr(layer, '_dims', None))
|
||||||
|
i += 1
|
||||||
|
if hasattr(layer, '_layers'):
|
||||||
|
queue.extend(layer._layers)
|
||||||
|
data = {'metas': tuple(metas), 'weights': tuple(weights), 'dims':
|
||||||
|
tuple(dims)}
|
||||||
|
return msgpack.dumps(data)
|
||||||
|
|
||||||
|
|
||||||
|
def model_from_bytes(model, bytes_data):
|
||||||
|
data = msgpack.loads(bytes_data)
|
||||||
|
metas = data['metas']
|
||||||
|
weights = data['weights']
|
||||||
|
dims = data['dims']
|
||||||
|
queue = [model]
|
||||||
|
i = 0
|
||||||
|
for layer in queue:
|
||||||
|
if hasattr(layer, '_mem'):
|
||||||
|
params = weights[i]
|
||||||
|
flat_mem = layer._mem._mem.ravel()
|
||||||
|
flat_params = params.ravel()
|
||||||
|
flat_mem[:flat_params.size] = flat_params
|
||||||
|
layer._mem._offsets.update(metas[i])
|
||||||
|
if hasattr(layer, '_dims'):
|
||||||
|
layer._dims.update(dims[i])
|
||||||
|
i += 1
|
||||||
|
if hasattr(layer, '_layers'):
|
||||||
|
queue.extend(layer._layers)
|
||||||
|
|
||||||
|
|
||||||
def print_table(data, title=None):
|
def print_table(data, title=None):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user