Update to/from bytes methods

commit 6b019b0540
parent c91b121aeb
@@ -9,7 +9,6 @@ import numpy
 cimport numpy as np
 import cytoolz
 import util
-import ujson
 
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
@@ -160,18 +159,18 @@ class TokenVectorEncoder(object):
             yield
 
     def to_bytes(self, **exclude):
-        data = {
-            'model': self.model,
-            'vocab': self.vocab
+        serialize = {
+            'model': lambda: model_to_bytes(self.model),
+            'vocab': lambda: self.vocab.to_bytes()
         }
-        return util.to_bytes(data, exclude)
+        return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
-        data = ujson.loads(bytes_data)
-        if 'model' not in exclude:
-            util.model_from_bytes(self.model, data['model'])
-        if 'vocab' not in exclude:
-            self.vocab.from_bytes(data['vocab'])
+        deserialize = {
+            'model': lambda b: model_from_bytes(self.model, b),
+            'vocab': lambda b: self.vocab.from_bytes(b)
+        }
+        util.from_bytes(deserialize, exclude)
         return self
 
     def to_disk(self, path, **exclude):
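The hunks above delegate to a pair of util helpers that the diff itself doesn't show. Going by the Vocab call site further down (util.from_bytes(bytes_data, setters, exclude)), each helper walks a table of callbacks keyed by field name and skips any key named in **exclude. A minimal sketch of that shape follows; msgpack as the codec is an assumption, and the tagger and parser call sites pass only (callbacks, exclude), so treat the exact signature as unsettled:

import msgpack

def to_bytes(getters, exclude):
    # Run every getter whose key is not excluded and pack the results.
    serial = {}
    for key, getter in getters.items():
        if key not in exclude:
            serial[key] = getter()
    return msgpack.dumps(serial)

def from_bytes(bytes_data, setters, exclude):
    # Unpack the payload and hand each stored field to its setter.
    msg = msgpack.loads(bytes_data)
    for key, setter in setters.items():
        if key not in exclude and key in msg:
            setter(msg[key])
    return msg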
@@ -290,6 +289,23 @@ class NeuralTagger(object):
         with self.model.use_params(params):
             yield
 
+    def to_bytes(self, **exclude):
+        serialize = {
+            'model': lambda: model_to_bytes(self.model),
+            'vocab': lambda: self.vocab.to_bytes()
+        }
+        return util.to_bytes(serialize, exclude)
+
+    def from_bytes(self, bytes_data, **exclude):
+        deserialize = {
+            'model': lambda b: model_from_bytes(self.model, b),
+            'vocab': lambda b: self.vocab.from_bytes(b)
+        }
+        util.from_bytes(deserialize, exclude)
+        return self
+
+
+
 class NeuralLabeller(NeuralTagger):
     name = 'nn_labeller'
     def __init__(self, vocab, model=True):
@@ -260,7 +260,14 @@ cdef class Parser:
         # Used to set input dimensions in network.
         lower.begin_training(lower.ops.allocate((500, token_vector_width)))
         upper.begin_training(upper.ops.allocate((500, hidden_width)))
-        return lower, upper
+        cfg = {
+            'nr_class': nr_class,
+            'depth': depth,
+            'token_vector_width': token_vector_width,
+            'hidden_width': hidden_width,
+            'maxout_pieces': parser_maxout_pieces
+        }
+        return (lower, upper), cfg
 
     def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
         """
@@ -611,7 +618,8 @@ cdef class Parser:
         for label in labels:
             self.moves.add_action(action, label)
         if self.model is True:
-            self.model = self.Model(self.moves.n_moves, **cfg)
+            self.model, cfg = self.Model(self.moves.n_moves, **cfg)
+            self.cfg.update(cfg)
 
     def preprocess_gold(self, docs_golds):
         for doc, gold in docs_golds:
@@ -633,11 +641,28 @@ cdef class Parser:
         with (path / 'model.bin').open('wb') as file_:
             self.model = dill.load(file_)
 
-    def to_bytes(self):
-        dill.dumps(self.model)
+    def to_bytes(self, **exclude):
+        serialize = {
+            'model': lambda: util.model_to_bytes(self.model),
+            'vocab': lambda: self.vocab.to_bytes(),
+            'moves': lambda: self.moves.to_bytes(),
+            'cfg': lambda: ujson.dumps(self.cfg)
+        }
+        return util.to_bytes(serialize, exclude)
 
-    def from_bytes(self, data):
-        self.model = dill.loads(data)
+    def from_bytes(self, bytes_data, **exclude):
+        deserialize = {
+            'vocab': lambda b: self.vocab.from_bytes(b),
+            'moves': lambda b: self.moves.from_bytes(b),
+            'cfg': lambda b: self.cfg.update(ujson.loads(b)),
+            'model': lambda b: None
+        }
+        msg = util.from_bytes(deserialize, exclude)
+        if 'model' not in exclude:
+            if self.model is True:
+                self.model = self.Model(**msg['cfg'])
+            util.model_from_disk(self.model, msg['model'])
+        return self
 
 
 class ParserStateError(ValueError):
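Threading cfg through Model and to_bytes is what lets from_bytes work on a Parser whose network was never built: the 'cfg' entry restores the hyperparameters, and the model is reconstructed from them before its weights are loaded. A hypothetical round trip, with variable names assumed rather than taken from the diff:

# Serialize a trained parser, then restore into a fresh instance.
data = parser.to_bytes()                   # packs model, vocab, moves, cfg

fresh = Parser(parser.vocab, moves=True, model=True)  # network not built yet
fresh.from_bytes(data)   # builds self.Model(**msg['cfg']), then restores weights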
@@ -291,12 +291,11 @@ cdef class Vocab:
         **exclude: Named attributes to prevent from being serialized.
        RETURNS (bytes): The serialized form of the `Vocab` object.
         """
-        data = {}
-        if 'strings' not in exclude:
-            data['strings'] = self.strings.to_bytes()
-        if 'lexemes' not in exclude:
-            data['lexemes'] = self.lexemes_to_bytes
-        return ujson.dumps(data)
+        getters = {
+            'strings': lambda: self.strings.to_bytes(),
+            'lexemes': lambda: self.lexemes_to_bytes()
+        }
+        return util.to_bytes(getters, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
         """Load state from a binary string.
@@ -305,12 +304,11 @@ cdef class Vocab:
         **exclude: Named attributes to prevent from being loaded.
         RETURNS (Vocab): The `Vocab` object.
         """
-        data = ujson.loads(bytes_data)
-        if 'strings' not in exclude:
-            self.strings.from_bytes(data['strings'])
-        if 'lexemes' not in exclude:
-            self.lexemes_from_bytes(data['lexemes'])
-        return self
+        setters = {
+            'strings': lambda b: self.strings.from_bytes(b),
+            'lexemes': lambda b: self.lexemes_from_bytes(b)
+        }
+        return util.from_bytes(bytes_data, setters, exclude)
 
     def lexemes_to_bytes(self):
         cdef hash_t key
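The Vocab half round-trips the same way; a small usage sketch, assuming an already populated vocab and default construction of the receiving one:

data = vocab.to_bytes()                  # strings and lexemes
partial = vocab.to_bytes(lexemes=True)   # any key passed via **exclude is skipped

other = Vocab()
other.from_bytes(data)                   # restores strings, then lexemes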