Update to/from bytes methods

This commit is contained in:
Matthew Honnibal 2017-05-29 10:14:20 +02:00
parent c91b121aeb
commit 6b019b0540
3 changed files with 67 additions and 28 deletions

View File

@@ -9,7 +9,6 @@ import numpy
cimport numpy as np cimport numpy as np
import cytoolz import cytoolz
import util import util
import ujson
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
from thinc.neural import Model, Maxout, Softmax, Affine from thinc.neural import Model, Maxout, Softmax, Affine
@@ -160,18 +159,18 @@ class TokenVectorEncoder(object):
yield yield
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
data = { serialize = {
'model': self.model, 'model': lambda: model_to_bytes(self.model),
'vocab': self.vocab 'vocab': lambda: self.vocab.to_bytes()
} }
return util.to_bytes(data, exclude) return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude): def from_bytes(self, bytes_data, **exclude):
data = ujson.loads(bytes_data) deserialize = {
if 'model' not in exclude: 'model': lambda b: model_from_bytes(self.model, b),
util.model_from_bytes(self.model, data['model']) 'vocab': lambda b: self.vocab.from_bytes(b)
if 'vocab' not in exclude: }
self.vocab.from_bytes(data['vocab']) util.from_bytes(deserialize, exclude)
return self return self
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
@@ -290,6 +289,23 @@ class NeuralTagger(object):
with self.model.use_params(params): with self.model.use_params(params):
yield yield
def to_bytes(self, **exclude):
serialize = {
'model': lambda: model_to_bytes(self.model),
'vocab': lambda: self.vocab.to_bytes()
}
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude):
deserialize = {
'model': lambda b: model_from_bytes(self.model, b),
'vocab': lambda b: self.vocab.from_bytes(b)
}
util.from_bytes(deserialize, exclude)
return self
class NeuralLabeller(NeuralTagger): class NeuralLabeller(NeuralTagger):
name = 'nn_labeller' name = 'nn_labeller'
def __init__(self, vocab, model=True): def __init__(self, vocab, model=True):

View File

@@ -260,7 +260,14 @@ cdef class Parser:
# Used to set input dimensions in network. # Used to set input dimensions in network.
lower.begin_training(lower.ops.allocate((500, token_vector_width))) lower.begin_training(lower.ops.allocate((500, token_vector_width)))
upper.begin_training(upper.ops.allocate((500, hidden_width))) upper.begin_training(upper.ops.allocate((500, hidden_width)))
return lower, upper cfg = {
'nr_class': nr_class,
'depth': depth,
'token_vector_width': token_vector_width,
'hidden_width': hidden_width,
'maxout_pieces': parser_maxout_pieces
}
return (lower, upper), cfg
def __init__(self, Vocab vocab, moves=True, model=True, **cfg): def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
""" """
@@ -611,7 +618,8 @@ cdef class Parser:
for label in labels: for label in labels:
self.moves.add_action(action, label) self.moves.add_action(action, label)
if self.model is True: if self.model is True:
self.model = self.Model(self.moves.n_moves, **cfg) self.model, cfg = self.Model(self.moves.n_moves, **cfg)
self.cfg.update(cfg)
def preprocess_gold(self, docs_golds): def preprocess_gold(self, docs_golds):
for doc, gold in docs_golds: for doc, gold in docs_golds:
@@ -633,11 +641,28 @@ cdef class Parser:
with (path / 'model.bin').open('wb') as file_: with (path / 'model.bin').open('wb') as file_:
self.model = dill.load(file_) self.model = dill.load(file_)
def to_bytes(self): def to_bytes(self, **exclude):
dill.dumps(self.model) serialize = {
'model': lambda: util.model_to_bytes(self.model),
'vocab': lambda: self.vocab.to_bytes(),
'moves': lambda: self.moves.to_bytes(),
'cfg': lambda: ujson.dumps(self.cfg)
}
return util.to_bytes(serialize, exclude)
def from_bytes(self, data): def from_bytes(self, bytes_data, **exclude):
self.model = dill.loads(data) deserialize = {
'vocab': lambda b: self.vocab.from_bytes(b),
'moves': lambda b: self.moves.from_bytes(b),
'cfg': lambda b: self.cfg.update(ujson.loads(b)),
'model': lambda b: None
}
msg = util.from_bytes(deserialize, exclude)
if 'model' not in exclude:
if self.model is True:
self.model = self.Model(**msg['cfg'])
util.model_from_disk(self.model, msg['model'])
return self
class ParserStateError(ValueError): class ParserStateError(ValueError):

View File

@@ -291,12 +291,11 @@ cdef class Vocab:
**exclude: Named attributes to prevent from being serialized. **exclude: Named attributes to prevent from being serialized.
RETURNS (bytes): The serialized form of the `Vocab` object. RETURNS (bytes): The serialized form of the `Vocab` object.
""" """
data = {} getters = {
if 'strings' not in exclude: 'strings': lambda: self.strings.to_bytes(),
data['strings'] = self.strings.to_bytes() 'lexemes': lambda: self.lexemes_to_bytes()
if 'lexemes' not in exclude: }
data['lexemes'] = self.lexemes_to_bytes return util.to_bytes(getters, exclude)
return ujson.dumps(data)
def from_bytes(self, bytes_data, **exclude): def from_bytes(self, bytes_data, **exclude):
"""Load state from a binary string. """Load state from a binary string.
@@ -305,12 +304,11 @@ cdef class Vocab:
**exclude: Named attributes to prevent from being loaded. **exclude: Named attributes to prevent from being loaded.
RETURNS (Vocab): The `Vocab` object. RETURNS (Vocab): The `Vocab` object.
""" """
data = ujson.loads(bytes_data) setters = {
if 'strings' not in exclude: 'strings': lambda b: self.strings.from_bytes(b),
self.strings.from_bytes(data['strings']) 'lexemes': lambda b: self.lexemes_from_bytes(b)
if 'lexemes' not in exclude: }
self.lexemes_from_bytes(data['lexemes']) return util.from_bytes(bytes_data, setters, exclude)
return self
def lexemes_to_bytes(self): def lexemes_to_bytes(self):
cdef hash_t key cdef hash_t key