mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Work on serialization for models
This commit is contained in:
parent
b007b0e5a0
commit
6dad4117ad
39
spacy/_ml.py
39
spacy/_ml.py
|
@ -1,3 +1,4 @@
|
|||
import ujson
|
||||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
from thinc.neural._classes.hash_embed import HashEmbed
|
||||
|
@ -15,7 +16,45 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
|
|||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
from .tokens.doc import Doc
|
||||
|
||||
import dill
|
||||
import numpy
|
||||
import io
|
||||
|
||||
|
||||
def model_to_bytes(model):
|
||||
weights = []
|
||||
metas = []
|
||||
queue = [model]
|
||||
i = 0
|
||||
for layer in queue:
|
||||
if hasattr(layer, '_mem'):
|
||||
weights.append(layer._mem.weights)
|
||||
metas.append(layer._mem._offsets)
|
||||
i += 1
|
||||
if hasattr(layer, '_layers'):
|
||||
queue.extend(layer._layers)
|
||||
data = {'metas': metas, 'weights': weights}
|
||||
# TODO: Replace the pickle here with something else
|
||||
return dill.dumps(data)
|
||||
|
||||
|
||||
def model_from_bytes(model, bytes_data):
|
||||
# TODO: Replace the pickle here with something else
|
||||
data = dill.loads(bytes_data)
|
||||
metas = data['metas']
|
||||
weights = data['weights']
|
||||
queue = [model]
|
||||
i = 0
|
||||
for layer in queue:
|
||||
if hasattr(layer, '_mem'):
|
||||
params = weights[i]
|
||||
flat_mem = layer._mem._mem.ravel()
|
||||
flat_params = params.ravel()
|
||||
flat_mem[:flat_params.size] = flat_params
|
||||
layer._mem._offsets.update(metas[i])
|
||||
i += 1
|
||||
if hasattr(layer, '_layers'):
|
||||
queue.extend(layer._layers)
|
||||
|
||||
|
||||
def _init_for_precomputed(W, ops):
|
||||
|
|
|
@ -9,6 +9,7 @@ import numpy
|
|||
cimport numpy as np
|
||||
import cytoolz
|
||||
import util
|
||||
import ujson
|
||||
|
||||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
|
@ -35,6 +36,7 @@ from .syntax import nonproj
|
|||
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
|
||||
from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
|
||||
from ._ml import model_to_bytes, model_from_bytes
|
||||
from .parts_of_speech import X
|
||||
|
||||
|
||||
|
@ -148,7 +150,6 @@ class TokenVectorEncoder(object):
|
|||
if self.model is True:
|
||||
self.model = self.Model()
|
||||
|
||||
|
||||
def use_params(self, params):
|
||||
"""Replace weights of models in the pipeline with those provided in the
|
||||
params dictionary.
|
||||
|
@ -158,6 +159,39 @@ class TokenVectorEncoder(object):
|
|||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
data = {
|
||||
'model': self.model,
|
||||
'vocab': self.vocab
|
||||
}
|
||||
return util.to_bytes(data, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
data = ujson.loads(bytes_data)
|
||||
if 'model' not in exclude:
|
||||
util.model_from_bytes(self.model, data['model'])
|
||||
if 'vocab' not in exclude:
|
||||
self.vocab.from_bytes(data['vocab'])
|
||||
return self
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
path = util.ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
if 'vocab' not in exclude:
|
||||
self.vocab.to_disk(path / 'vocab')
|
||||
if 'model' not in exclude:
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
file_.write(util.model_to_bytes(self.model))
|
||||
|
||||
def from_disk(self, path, **exclude):
|
||||
path = util.ensure_path(path)
|
||||
if 'vocab' not in exclude:
|
||||
self.vocab.from_disk(path / 'vocab')
|
||||
if 'model.bin' not in exclude:
|
||||
with (path / 'model.bin').open('rb') as file_:
|
||||
util.model_from_bytes(self.model, file_.read())
|
||||
|
||||
|
||||
class NeuralTagger(object):
|
||||
name = 'nn_tagger'
|
||||
|
|
|
@ -2,12 +2,38 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from ..util import ensure_path
|
||||
from .._ml import model_to_bytes, model_from_bytes
|
||||
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from thinc.neural import Maxout, Softmax
|
||||
from thinc.api import chain
|
||||
|
||||
|
||||
@pytest.mark.parametrize('text', ['hello/world', 'hello world'])
|
||||
def test_util_ensure_path_succeeds(text):
|
||||
path = ensure_path(text)
|
||||
assert isinstance(path, Path)
|
||||
|
||||
|
||||
def test_simple_model_roundtrip_bytes():
|
||||
model = Maxout(5, 10, pieces=2)
|
||||
model.b += 1
|
||||
data = model_to_bytes(model)
|
||||
model.b -= 1
|
||||
model_from_bytes(model, data)
|
||||
assert model.b[0, 0] == 1
|
||||
|
||||
|
||||
def test_multi_model_roundtrip_bytes():
|
||||
model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
|
||||
model._layers[0].b += 1
|
||||
model._layers[1].b += 2
|
||||
data = model_to_bytes(model)
|
||||
model._layers[0].b -= 1
|
||||
model._layers[1].b -= 2
|
||||
model_from_bytes(model, data)
|
||||
assert model._layers[0].b[0, 0] == 1
|
||||
assert model._layers[1].b[0, 0] == 2
|
||||
|
||||
|
||||
|
|
|
@ -408,6 +408,18 @@ def get_raw_input(description, default=False):
|
|||
return user_input
|
||||
|
||||
|
||||
def to_bytes(unserialized, exclude):
|
||||
serialized = {}
|
||||
for key, value in unserialized.items():
|
||||
if key in exclude:
|
||||
continue
|
||||
elif hasattr(value, 'to_bytes'):
|
||||
serialized[key] = value.to_bytes()
|
||||
else:
|
||||
serialized[key] = ujson.dumps(value)
|
||||
return ujson.dumps(serialized)
|
||||
|
||||
|
||||
def print_table(data, title=None):
|
||||
"""Print data in table format.
|
||||
|
||||
|
|
|
@ -56,15 +56,7 @@ cdef class Vocab:
|
|||
if strings:
|
||||
for string in strings:
|
||||
self.strings.add(string)
|
||||
# Load strings in a special order, so that we have an onset number for
|
||||
# the vocabulary. This way, when words are added in order, the orth ID
|
||||
# is the frequency rank of the word, plus a certain offset. The structural
|
||||
# strings are loaded first, because the vocab is open-class, and these
|
||||
# symbols are closed class.
|
||||
# TODO: Actually this has turned out to be a pain in the ass...
|
||||
# It means the data is invalidated when we add a symbol :(
|
||||
# Need to rethink this.
|
||||
for name in symbols.NAMES + list(sorted(tag_map.keys())):
|
||||
for name in tag_map.keys():
|
||||
if name:
|
||||
self.strings.add(name)
|
||||
self.lex_attr_getters = lex_attr_getters
|
||||
|
|
Loading…
Reference in New Issue
Block a user