Mirror of https://github.com/explosion/spaCy.git
Work on serialization for models

commit 6dad4117ad
parent b007b0e5a0
spacy/_ml.py (39 changed lines)
@@ -1,3 +1,4 @@
+import ujson
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
@@ -15,9 +16,47 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
 from .tokens.doc import Doc
 
+import dill
 import numpy
+import io
+
+
+def model_to_bytes(model):
+    weights = []
+    metas = []
+    queue = [model]
+    i = 0
+    for layer in queue:
+        if hasattr(layer, '_mem'):
+            weights.append(layer._mem.weights)
+            metas.append(layer._mem._offsets)
+            i += 1
+        if hasattr(layer, '_layers'):
+            queue.extend(layer._layers)
+    data = {'metas': metas, 'weights': weights}
+    # TODO: Replace the pickle here with something else
+    return dill.dumps(data)
+
+
+def model_from_bytes(model, bytes_data):
+    # TODO: Replace the pickle here with something else
+    data = dill.loads(bytes_data)
+    metas = data['metas']
+    weights = data['weights']
+    queue = [model]
+    i = 0
+    for layer in queue:
+        if hasattr(layer, '_mem'):
+            params = weights[i]
+            flat_mem = layer._mem._mem.ravel()
+            flat_params = params.ravel()
+            flat_mem[:flat_params.size] = flat_params
+            layer._mem._offsets.update(metas[i])
+            i += 1
+        if hasattr(layer, '_layers'):
+            queue.extend(layer._layers)
 
 
 def _init_for_precomputed(W, ops):
     if (W**2).sum() != 0.:
         return
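
model_to_bytes walks the model and its sublayers via _layers, collecting each parametric layer's weight array and offset table in traversal order; model_from_bytes repeats the same walk and copies the flat weights back into each layer's memory in place. A minimal round-trip sketch under that API, mirroring the test added further down (it assumes an installed build of this branch, and that thinc's Maxout exposes its bias as model.b):

    from spacy._ml import model_to_bytes, model_from_bytes
    from thinc.neural import Maxout

    model = Maxout(5, 10, pieces=2)    # same shapes as in the test below
    model.b += 1                       # make the weights distinguishable
    data = model_to_bytes(model)       # bytes snapshot of weights + offsets
    model.b -= 1                       # clobber the weights in place...
    model_from_bytes(model, data)      # ...then restore them from the snapshot
    assert model.b[0, 0] == 1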
@@ -9,6 +9,7 @@ import numpy
 cimport numpy as np
 import cytoolz
 import util
+import ujson
 
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
@@ -35,6 +36,7 @@ from .syntax import nonproj
 
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
 from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
+from ._ml import model_to_bytes, model_from_bytes
 from .parts_of_speech import X
 
 
@@ -148,7 +150,6 @@ class TokenVectorEncoder(object):
         if self.model is True:
             self.model = self.Model()
 
-
    def use_params(self, params):
        """Replace weights of models in the pipeline with those provided in the
        params dictionary.
@@ -158,6 +159,39 @@ class TokenVectorEncoder(object):
         with self.model.use_params(params):
             yield
 
+    def to_bytes(self, **exclude):
+        data = {
+            'model': self.model,
+            'vocab': self.vocab
+        }
+        return util.to_bytes(data, exclude)
+
+    def from_bytes(self, bytes_data, **exclude):
+        data = ujson.loads(bytes_data)
+        if 'model' not in exclude:
+            util.model_from_bytes(self.model, data['model'])
+        if 'vocab' not in exclude:
+            self.vocab.from_bytes(data['vocab'])
+        return self
+
+    def to_disk(self, path, **exclude):
+        path = util.ensure_path(path)
+        if not path.exists():
+            path.mkdir()
+        if 'vocab' not in exclude:
+            self.vocab.to_disk(path / 'vocab')
+        if 'model' not in exclude:
+            with (path / 'model.bin').open('wb') as file_:
+                file_.write(util.model_to_bytes(self.model))
+
+    def from_disk(self, path, **exclude):
+        path = util.ensure_path(path)
+        if 'vocab' not in exclude:
+            self.vocab.from_disk(path / 'vocab')
+        if 'model.bin' not in exclude:
+            with (path / 'model.bin').open('rb') as file_:
+                util.model_from_bytes(self.model, file_.read())
+
 
 class NeuralTagger(object):
     name = 'nn_tagger'
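
The new TokenVectorEncoder methods (apparently in spacy/pipeline.pyx, given the surrounding NeuralTagger context) split serialization between the vocab and the thinc model: the disk variants write a vocab directory next to model.bin. A rough round-trip sketch for the disk path; the constructor call and output path are illustrative, and it assumes util re-exports the model_to_bytes/model_from_bytes helpers that to_disk/from_disk call (in this diff they are defined in _ml.py):

    from pathlib import Path
    from spacy.vocab import Vocab
    from spacy.pipeline import TokenVectorEncoder

    tok2vec = TokenVectorEncoder(Vocab())   # assumed constructor signature
    tok2vec.model = tok2vec.Model()         # build weights first (see the hunk above)
    tok2vec.to_disk(Path('/tmp/tok2vec'))   # writes vocab/ plus model.bin

    restored = TokenVectorEncoder(Vocab())
    restored.model = restored.Model()       # shapes must match the saved weights
    restored.from_disk(Path('/tmp/tok2vec'))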
@@ -2,12 +2,38 @@
 from __future__ import unicode_literals
 
 from ..util import ensure_path
+from .._ml import model_to_bytes, model_from_bytes
 
 from pathlib import Path
 import pytest
+from thinc.neural import Maxout, Softmax
+from thinc.api import chain
 
 
 @pytest.mark.parametrize('text', ['hello/world', 'hello world'])
 def test_util_ensure_path_succeeds(text):
     path = ensure_path(text)
     assert isinstance(path, Path)
+
+
+def test_simple_model_roundtrip_bytes():
+    model = Maxout(5, 10, pieces=2)
+    model.b += 1
+    data = model_to_bytes(model)
+    model.b -= 1
+    model_from_bytes(model, data)
+    assert model.b[0, 0] == 1
+
+
+def test_multi_model_roundtrip_bytes():
+    model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
+    model._layers[0].b += 1
+    model._layers[1].b += 2
+    data = model_to_bytes(model)
+    model._layers[0].b -= 1
+    model._layers[1].b -= 2
+    model_from_bytes(model, data)
+    assert model._layers[0].b[0, 0] == 1
+    assert model._layers[1].b[0, 0] == 2
@@ -408,6 +408,18 @@ def get_raw_input(description, default=False):
     return user_input
 
 
+def to_bytes(unserialized, exclude):
+    serialized = {}
+    for key, value in unserialized.items():
+        if key in exclude:
+            continue
+        elif hasattr(value, 'to_bytes'):
+            serialized[key] = value.to_bytes()
+        else:
+            serialized[key] = ujson.dumps(value)
+    return ujson.dumps(serialized)
+
+
 def print_table(data, title=None):
     """Print data in table format.
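
This to_bytes helper (in spacy/util.py, judging by the surrounding get_raw_input/print_table context) JSON-encodes any value that lacks its own to_bytes method and then JSON-encodes the whole dict, so plain values end up double-encoded inside the outer string. A small standalone sketch of that behaviour, with illustrative inputs (the helper body is copied from the hunk above):

    import ujson

    def to_bytes(unserialized, exclude):
        serialized = {}
        for key, value in unserialized.items():
            if key in exclude:
                continue
            elif hasattr(value, 'to_bytes'):
                serialized[key] = value.to_bytes()
            else:
                serialized[key] = ujson.dumps(value)
        return ujson.dumps(serialized)

    data = to_bytes({'width': 128, 'name': 'tok2vec'}, exclude=('name',))
    # data == '{"width":"128"}' -- 'name' is skipped, 128 is double-encoded as a string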
@@ -56,15 +56,7 @@ cdef class Vocab:
         if strings:
             for string in strings:
                 self.strings.add(string)
-        # Load strings in a special order, so that we have an onset number for
-        # the vocabulary. This way, when words are added in order, the orth ID
-        # is the frequency rank of the word, plus a certain offset. The structural
-        # strings are loaded first, because the vocab is open-class, and these
-        # symbols are closed class.
-        # TODO: Actually this has turned out to be a pain in the ass...
-        # It means the data is invalidated when we add a symbol :(
-        # Need to rethink this.
-        for name in symbols.NAMES + list(sorted(tag_map.keys())):
+        for name in tag_map.keys():
             if name:
                 self.strings.add(name)
         self.lex_attr_getters = lex_attr_getters