mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Update to/from bytes methods
This commit is contained in:
parent
c91b121aeb
commit
6b019b0540
|
@ -9,7 +9,6 @@ import numpy
|
|||
cimport numpy as np
|
||||
import cytoolz
|
||||
import util
|
||||
import ujson
|
||||
|
||||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
|
@ -160,18 +159,18 @@ class TokenVectorEncoder(object):
|
|||
yield
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
data = {
|
||||
'model': self.model,
|
||||
'vocab': self.vocab
|
||||
serialize = {
|
||||
'model': lambda: model_to_bytes(self.model),
|
||||
'vocab': lambda: self.vocab.to_bytes()
|
||||
}
|
||||
return util.to_bytes(data, exclude)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
data = ujson.loads(bytes_data)
|
||||
if 'model' not in exclude:
|
||||
util.model_from_bytes(self.model, data['model'])
|
||||
if 'vocab' not in exclude:
|
||||
self.vocab.from_bytes(data['vocab'])
|
||||
deserialize = {
|
||||
'model': lambda b: model_from_bytes(self.model, b),
|
||||
'vocab': lambda b: self.vocab.from_bytes(b)
|
||||
}
|
||||
util.from_bytes(deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
|
@ -290,6 +289,23 @@ class NeuralTagger(object):
|
|||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
serialize = {
|
||||
'model': lambda: model_to_bytes(self.model),
|
||||
'vocab': lambda: self.vocab.to_bytes()
|
||||
}
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
deserialize = {
|
||||
'model': lambda b: model_from_bytes(self.model, b),
|
||||
'vocab': lambda b: self.vocab.from_bytes(b)
|
||||
}
|
||||
util.from_bytes(deserialize, exclude)
|
||||
return self
|
||||
|
||||
|
||||
|
||||
class NeuralLabeller(NeuralTagger):
|
||||
name = 'nn_labeller'
|
||||
def __init__(self, vocab, model=True):
|
||||
|
|
|
@ -260,7 +260,14 @@ cdef class Parser:
|
|||
# Used to set input dimensions in network.
|
||||
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
||||
upper.begin_training(upper.ops.allocate((500, hidden_width)))
|
||||
return lower, upper
|
||||
cfg = {
|
||||
'nr_class': nr_class,
|
||||
'depth': depth,
|
||||
'token_vector_width': token_vector_width,
|
||||
'hidden_width': hidden_width,
|
||||
'maxout_pieces': parser_maxout_pieces
|
||||
}
|
||||
return (lower, upper), cfg
|
||||
|
||||
def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
|
||||
"""
|
||||
|
@ -611,7 +618,8 @@ cdef class Parser:
|
|||
for label in labels:
|
||||
self.moves.add_action(action, label)
|
||||
if self.model is True:
|
||||
self.model = self.Model(self.moves.n_moves, **cfg)
|
||||
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
||||
self.cfg.update(cfg)
|
||||
|
||||
def preprocess_gold(self, docs_golds):
|
||||
for doc, gold in docs_golds:
|
||||
|
@ -633,11 +641,28 @@ cdef class Parser:
|
|||
with (path / 'model.bin').open('wb') as file_:
|
||||
self.model = dill.load(file_)
|
||||
|
||||
def to_bytes(self):
|
||||
dill.dumps(self.model)
|
||||
def to_bytes(self, **exclude):
|
||||
serialize = {
|
||||
'model': lambda: util.model_to_bytes(self.model),
|
||||
'vocab': lambda: self.vocab.to_bytes(),
|
||||
'moves': lambda: self.moves.to_bytes(),
|
||||
'cfg': lambda: ujson.dumps(self.cfg)
|
||||
}
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, data):
|
||||
self.model = dill.loads(data)
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
deserialize = {
|
||||
'vocab': lambda b: self.vocab.from_bytes(b),
|
||||
'moves': lambda b: self.moves.from_bytes(b),
|
||||
'cfg': lambda b: self.cfg.update(ujson.loads(b)),
|
||||
'model': lambda b: None
|
||||
}
|
||||
msg = util.from_bytes(deserialize, exclude)
|
||||
if 'model' not in exclude:
|
||||
if self.model is True:
|
||||
self.model = self.Model(**msg['cfg'])
|
||||
util.model_from_disk(self.model, msg['model'])
|
||||
return self
|
||||
|
||||
|
||||
class ParserStateError(ValueError):
|
||||
|
|
|
@ -291,12 +291,11 @@ cdef class Vocab:
|
|||
**exclude: Named attributes to prevent from being serialized.
|
||||
RETURNS (bytes): The serialized form of the `Vocab` object.
|
||||
"""
|
||||
data = {}
|
||||
if 'strings' not in exclude:
|
||||
data['strings'] = self.strings.to_bytes()
|
||||
if 'lexemes' not in exclude:
|
||||
data['lexemes'] = self.lexemes_to_bytes
|
||||
return ujson.dumps(data)
|
||||
getters = {
|
||||
'strings': lambda: self.strings.to_bytes(),
|
||||
'lexemes': lambda: self.lexemes_to_bytes()
|
||||
}
|
||||
return util.to_bytes(getters, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
"""Load state from a binary string.
|
||||
|
@ -305,12 +304,11 @@ cdef class Vocab:
|
|||
**exclude: Named attributes to prevent from being loaded.
|
||||
RETURNS (Vocab): The `Vocab` object.
|
||||
"""
|
||||
data = ujson.loads(bytes_data)
|
||||
if 'strings' not in exclude:
|
||||
self.strings.from_bytes(data['strings'])
|
||||
if 'lexemes' not in exclude:
|
||||
self.lexemes_from_bytes(data['lexemes'])
|
||||
return self
|
||||
setters = {
|
||||
'strings': lambda b: self.strings.from_bytes(b),
|
||||
'lexemes': lambda b: self.lexemes_from_bytes(b)
|
||||
}
|
||||
return util.from_bytes(bytes_data, setters, exclude)
|
||||
|
||||
def lexemes_to_bytes(self):
|
||||
cdef hash_t key
|
||||
|
|
Loading…
Reference in New Issue
Block a user