Update to/from bytes methods

This commit is contained in:
Matthew Honnibal 2017-05-29 10:14:20 +02:00
parent c91b121aeb
commit 6b019b0540
3 changed files with 67 additions and 28 deletions

View File

@@ -9,7 +9,6 @@ import numpy
 cimport numpy as np
 import cytoolz
 import util
-import ujson
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
@@ -160,18 +159,18 @@ class TokenVectorEncoder(object):
             yield
 
     def to_bytes(self, **exclude):
-        data = {
-            'model': self.model,
-            'vocab': self.vocab
+        serialize = {
+            'model': lambda: model_to_bytes(self.model),
+            'vocab': lambda: self.vocab.to_bytes()
         }
-        return util.to_bytes(data, exclude)
+        return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
-        data = ujson.loads(bytes_data)
-        if 'model' not in exclude:
-            util.model_from_bytes(self.model, data['model'])
-        if 'vocab' not in exclude:
-            self.vocab.from_bytes(data['vocab'])
+        deserialize = {
+            'model': lambda b: model_from_bytes(self.model, b),
+            'vocab': lambda b: self.vocab.from_bytes(b)
+        }
+        util.from_bytes(deserialize, exclude)
         return self
 
     def to_disk(self, path, **exclude):
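The util.to_bytes and util.from_bytes helpers that these methods rely on are not part of this diff. As a rough illustration of the pattern, here is a minimal sketch of how such helpers could look, assuming msgpack as the wire format and following the Vocab-style call util.from_bytes(bytes_data, setters, exclude) used later in this commit (the pipeline hunks call util.from_bytes without the bytes_data argument); the actual spacy.util implementation may differ.

# Minimal sketch of the helper pattern; msgpack and OrderedDict are
# assumptions, not taken from this diff.
from collections import OrderedDict
import msgpack


def to_bytes(getters, exclude):
    # Call each getter whose key was not excluded and pack the results.
    serialized = OrderedDict()
    for key, getter in getters.items():
        if key not in exclude:
            serialized[key] = getter()
    return msgpack.packb(serialized, use_bin_type=True)


def from_bytes(bytes_data, setters, exclude):
    # Unpack the payload and hand each section to its setter.
    msg = msgpack.unpackb(bytes_data, raw=False)
    for key, setter in setters.items():
        if key not in exclude and key in msg:
            setter(msg[key])
    return msg

Each section is produced lazily by its getter, so excluded sections are never serialized at all; on load, each setter receives only its own chunk of bytes.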
@@ -290,6 +289,23 @@ class NeuralTagger(object):
         with self.model.use_params(params):
             yield
 
+    def to_bytes(self, **exclude):
+        serialize = {
+            'model': lambda: model_to_bytes(self.model),
+            'vocab': lambda: self.vocab.to_bytes()
+        }
+        return util.to_bytes(serialize, exclude)
+
+    def from_bytes(self, bytes_data, **exclude):
+        deserialize = {
+            'model': lambda b: model_from_bytes(self.model, b),
+            'vocab': lambda b: self.vocab.from_bytes(b)
+        }
+        util.from_bytes(deserialize, exclude)
+        return self
+
 
 class NeuralLabeller(NeuralTagger):
     name = 'nn_labeller'
     def __init__(self, vocab, model=True):
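A hypothetical round trip with the new NeuralTagger methods might look like the following; the tagger instances are assumed to exist and to have their models already built, since this from_bytes loads weights into an existing model rather than constructing one.

# Illustrative usage; `tagger` and `new_tagger` are assumed NeuralTagger
# instances whose models have already been created.
tagger_bytes = tagger.to_bytes()
new_tagger.from_bytes(tagger_bytes)

# Only key membership is checked against **exclude, so naming a section
# as a keyword argument skips it entirely:
weights_only = tagger.to_bytes(vocab=True)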

View File

@@ -260,7 +260,14 @@ cdef class Parser:
         # Used to set input dimensions in network.
         lower.begin_training(lower.ops.allocate((500, token_vector_width)))
         upper.begin_training(upper.ops.allocate((500, hidden_width)))
-        return lower, upper
+        cfg = {
+            'nr_class': nr_class,
+            'depth': depth,
+            'token_vector_width': token_vector_width,
+            'hidden_width': hidden_width,
+            'maxout_pieces': parser_maxout_pieces
+        }
+        return (lower, upper), cfg
 
     def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
         """
@@ -611,7 +618,8 @@
             for label in labels:
                 self.moves.add_action(action, label)
         if self.model is True:
-            self.model = self.Model(self.moves.n_moves, **cfg)
+            self.model, cfg = self.Model(self.moves.n_moves, **cfg)
+            self.cfg.update(cfg)
 
     def preprocess_gold(self, docs_golds):
         for doc, gold in docs_golds:
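Returning cfg from Model and merging it into self.cfg is what later lets a deserialized parser rebuild a network of the right shape before any weights are restored (see the from_bytes hunk below, which expands msg['cfg'] back into self.Model). A generic, self-contained sketch of that pattern, not spaCy code:

# Build a network and return the settings needed to rebuild it later.
def build_net(nr_class, hidden_width=64):
    cfg = {'nr_class': nr_class, 'hidden_width': hidden_width}
    weights = [[0.0] * hidden_width for _ in range(nr_class)]  # stand-in for real layers
    return weights, cfg

model, cfg = build_net(10)
# Persist cfg alongside the weights; on load, the same call recreates a
# network with identical shapes, ready to receive the saved parameters.
restored, _ = build_net(**cfg)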
@@ -633,11 +641,28 @@
         with (path / 'model.bin').open('wb') as file_:
             self.model = dill.load(file_)
 
-    def to_bytes(self):
-        dill.dumps(self.model)
+    def to_bytes(self, **exclude):
+        serialize = {
+            'model': lambda: util.model_to_bytes(self.model),
+            'vocab': lambda: self.vocab.to_bytes(),
+            'moves': lambda: self.moves.to_bytes(),
+            'cfg': lambda: ujson.dumps(self.cfg)
+        }
+        return util.to_bytes(serialize, exclude)
 
-    def from_bytes(self, data):
-        self.model = dill.loads(data)
+    def from_bytes(self, bytes_data, **exclude):
+        deserialize = {
+            'vocab': lambda b: self.vocab.from_bytes(b),
+            'moves': lambda b: self.moves.from_bytes(b),
+            'cfg': lambda b: self.cfg.update(ujson.loads(b)),
+            'model': lambda b: None
+        }
+        msg = util.from_bytes(deserialize, exclude)
+        if 'model' not in exclude:
+            if self.model is True:
+                self.model = self.Model(**msg['cfg'])
+            util.model_from_disk(self.model, msg['model'])
+        return self
 
 
 class ParserStateError(ValueError):
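The switch away from dill means the model object itself is no longer pickled; instead util.model_to_bytes and util.model_from_bytes (not shown in this diff) pack and restore the network's weights. Purely as a hedged sketch, a pair of helpers for a model exposing a list of numpy parameter arrays could look like this; the params attribute and the msgpack encoding are assumptions, not spaCy's actual implementation.

# Hedged sketch of a weight-packing helper pair for a model with a
# `params` list of writable numpy arrays (an assumed interface).
import numpy
import msgpack


def model_to_bytes(model):
    # Record shape, dtype and raw bytes for each parameter array.
    msg = {'params': [{'shape': p.shape, 'dtype': str(p.dtype),
                       'data': p.tobytes()} for p in model.params]}
    return msgpack.packb(msg, use_bin_type=True)


def model_from_bytes(model, bytes_data):
    # Copy the saved values back into the existing parameter arrays.
    msg = msgpack.unpackb(bytes_data, raw=False)
    for param, spec in zip(model.params, msg['params']):
        loaded = numpy.frombuffer(spec['data'], dtype=spec['dtype'])
        param[:] = loaded.reshape(spec['shape'])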

View File

@@ -291,12 +291,11 @@ cdef class Vocab:
         **exclude: Named attributes to prevent from being serialized.
         RETURNS (bytes): The serialized form of the `Vocab` object.
         """
-        data = {}
-        if 'strings' not in exclude:
-            data['strings'] = self.strings.to_bytes()
-        if 'lexemes' not in exclude:
-            data['lexemes'] = self.lexemes_to_bytes
-        return ujson.dumps(data)
+        getters = {
+            'strings': lambda: self.strings.to_bytes(),
+            'lexemes': lambda: self.lexemes_to_bytes()
+        }
+        return util.to_bytes(getters, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
         """Load state from a binary string.
@@ -305,12 +304,11 @@ cdef class Vocab:
         **exclude: Named attributes to prevent from being loaded.
         RETURNS (Vocab): The `Vocab` object.
         """
-        data = ujson.loads(bytes_data)
-        if 'strings' not in exclude:
-            self.strings.from_bytes(data['strings'])
-        if 'lexemes' not in exclude:
-            self.lexemes_from_bytes(data['lexemes'])
-        return self
+        setters = {
+            'strings': lambda b: self.strings.from_bytes(b),
+            'lexemes': lambda b: self.lexemes_from_bytes(b)
+        }
+        return util.from_bytes(bytes_data, setters, exclude)
 
     def lexemes_to_bytes(self):
         cdef hash_t key
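With both Vocab methods routed through the shared helpers, a byte-level round trip reduces to the following; the two Vocab instances are assumed to already exist.

# Illustrative round trip between two assumed Vocab instances.
vocab_bytes = vocab.to_bytes()
other_vocab.from_bytes(vocab_bytes)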