mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Work on to/from bytes/disk serialization methods
This commit is contained in:
parent
6b019b0540
commit
ff26aa6c37
|
@ -366,20 +366,22 @@ class Language(object):
|
|||
>>> nlp.to_disk('/path/to/models')
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
if not path.is_dir():
|
||||
raise IOError("Output path must be a directory")
|
||||
props = {}
|
||||
for name, value in self.__dict__.items():
|
||||
if name in disable:
|
||||
continue
|
||||
if hasattr(value, 'to_disk'):
|
||||
value.to_disk(path / name)
|
||||
else:
|
||||
props[name] = value
|
||||
with (path / 'props.pickle').open('wb') as file_:
|
||||
dill.dump(props, file_)
|
||||
with path.open('wb') as file_:
|
||||
file_.write(self.to_bytes(disable))
|
||||
#serializers = {
|
||||
# 'vocab': lambda p: self.vocab.to_disk(p),
|
||||
# 'tokenizer': lambda p: self.tokenizer.to_disk(p, vocab=False),
|
||||
# 'meta.json': lambda p: ujson.dump(p.open('w'), self.meta)
|
||||
#}
|
||||
#for proc in self.pipeline:
|
||||
# if not hasattr(proc, 'name'):
|
||||
# continue
|
||||
# if proc.name in disable:
|
||||
# continue
|
||||
# if not hasattr(proc, 'to_disk'):
|
||||
# continue
|
||||
# serializers[proc.name] = lambda p: proc.to_disk(p, vocab=False)
|
||||
#util.to_disk(serializers, path)
|
||||
|
||||
def from_disk(self, path, disable=[]):
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
|
@ -396,13 +398,24 @@ class Language(object):
|
|||
>>> nlp = Language().from_disk('/path/to/models')
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
for name in path.iterdir():
|
||||
if name not in disable and hasattr(self, str(name)):
|
||||
getattr(self, name).from_disk(path / name)
|
||||
with (path / 'props.pickle').open('rb') as file_:
|
||||
with path.open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
self.from_bytes(bytes_data, disable)
|
||||
return self
|
||||
return self.from_bytes(bytes_data, disable)
|
||||
#deserializers = {
|
||||
# 'vocab': lambda p: self.vocab.from_disk(p),
|
||||
# 'tokenizer': lambda p: self.tokenizer.from_disk(p, vocab=False),
|
||||
# 'meta.json': lambda p: ujson.dump(p.open('w'), self.meta)
|
||||
#}
|
||||
#for proc in self.pipeline:
|
||||
# if not hasattr(proc, 'name'):
|
||||
# continue
|
||||
# if proc.name in disable:
|
||||
# continue
|
||||
# if not hasattr(proc, 'to_disk'):
|
||||
# continue
|
||||
# deserializers[proc.name] = lambda p: proc.from_disk(p, vocab=False)
|
||||
#util.from_disk(deserializers, path)
|
||||
#return self
|
||||
|
||||
def to_bytes(self, disable=[]):
|
||||
"""Serialize the current state to a binary string.
|
||||
|
@ -411,11 +424,20 @@ class Language(object):
|
|||
from being serialized.
|
||||
RETURNS (bytes): The serialized form of the `Language` object.
|
||||
"""
|
||||
props = dict(self.__dict__)
|
||||
for key in disable:
|
||||
if key in props:
|
||||
props.pop(key)
|
||||
return dill.dumps(props, -1)
|
||||
serializers = {
|
||||
'vocab': lambda: self.vocab.to_bytes(),
|
||||
'tokenizer': lambda: self.tokenizer.to_bytes(vocab=False),
|
||||
'meta': lambda: ujson.dumps(self.meta)
|
||||
}
|
||||
for proc in self.pipeline:
|
||||
if not hasattr(proc, 'name'):
|
||||
continue
|
||||
if proc.name in disable:
|
||||
continue
|
||||
if not hasattr(proc, 'to_bytes'):
|
||||
continue
|
||||
serializers[proc.name] = lambda: proc.to_bytes(p, vocab=False)
|
||||
return util.to_bytes(serializers)
|
||||
|
||||
def from_bytes(self, bytes_data, disable=[]):
|
||||
"""Load state from a binary string.
|
||||
|
@ -424,12 +446,23 @@ class Language(object):
|
|||
disable (list): Names of the pipeline components to disable.
|
||||
RETURNS (Language): The `Language` object.
|
||||
"""
|
||||
props = dill.loads(bytes_data)
|
||||
for key, value in props.items():
|
||||
if key not in disable:
|
||||
setattr(self, key, value)
|
||||
deserializers = {
|
||||
'vocab': lambda b: self.vocab.from_bytes(b),
|
||||
'tokenizer': lambda b: self.tokenizer.from_bytes(b, vocab=False),
|
||||
'meta': lambda b: self.meta.update(ujson.loads(b))
|
||||
}
|
||||
for proc in self.pipeline:
|
||||
if not hasattr(proc, 'name'):
|
||||
continue
|
||||
if proc.name in disable:
|
||||
continue
|
||||
if not hasattr(proc, 'to_disk'):
|
||||
continue
|
||||
deserializers[proc.name] = lambda b: proc.from_bytes(b, vocab=False)
|
||||
util.from_bytes(deserializers, bytes_data)
|
||||
return self
|
||||
|
||||
|
||||
def _pipe(func, docs):
|
||||
for doc in docs:
|
||||
func(doc)
|
||||
|
|
|
@ -35,7 +35,6 @@ from .syntax import nonproj
|
|||
|
||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
|
||||
from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
|
||||
from ._ml import model_to_bytes, model_from_bytes
|
||||
from .parts_of_speech import X
|
||||
|
||||
|
||||
|
@ -160,36 +159,33 @@ class TokenVectorEncoder(object):
|
|||
|
||||
def to_bytes(self, **exclude):
|
||||
serialize = {
|
||||
'model': lambda: model_to_bytes(self.model),
|
||||
'model': lambda: util.model_to_bytes(self.model),
|
||||
'vocab': lambda: self.vocab.to_bytes()
|
||||
}
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
deserialize = {
|
||||
'model': lambda b: model_from_bytes(self.model, b),
|
||||
'model': lambda b: util.model_from_bytes(self.model, b),
|
||||
'vocab': lambda b: self.vocab.from_bytes(b)
|
||||
}
|
||||
util.from_bytes(deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
path = util.ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
if 'vocab' not in exclude:
|
||||
self.vocab.to_disk(path / 'vocab')
|
||||
if 'model' not in exclude:
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
file_.write(util.model_to_bytes(self.model))
|
||||
serialize = {
|
||||
'model': lambda p: p.open('w').write(util.model_to_bytes(self.model)),
|
||||
'vocab': lambda p: self.vocab.to_disk(p)
|
||||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, **exclude):
|
||||
path = util.ensure_path(path)
|
||||
if 'vocab' not in exclude:
|
||||
self.vocab.from_disk(path / 'vocab')
|
||||
if 'model.bin' not in exclude:
|
||||
with (path / 'model.bin').open('rb') as file_:
|
||||
util.model_from_bytes(self.model, file_.read())
|
||||
deserialize = {
|
||||
'model': lambda p: util.model_from_bytes(self.model, p.open('rb').read()),
|
||||
'vocab': lambda p: self.vocab.from_disk(p)
|
||||
}
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
||||
|
||||
class NeuralTagger(object):
|
||||
|
@ -291,19 +287,33 @@ class NeuralTagger(object):
|
|||
|
||||
def to_bytes(self, **exclude):
|
||||
serialize = {
|
||||
'model': lambda: model_to_bytes(self.model),
|
||||
'model': lambda: util.model_to_bytes(self.model),
|
||||
'vocab': lambda: self.vocab.to_bytes()
|
||||
}
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
deserialize = {
|
||||
'model': lambda b: model_from_bytes(self.model, b),
|
||||
'model': lambda b: util.model_from_bytes(self.model, b),
|
||||
'vocab': lambda b: self.vocab.from_bytes(b)
|
||||
}
|
||||
util.from_bytes(deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
serialize = {
|
||||
'model': lambda p: p.open('w').write(util.model_to_bytes(self.model)),
|
||||
'vocab': lambda p: self.vocab.to_disk(p)
|
||||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, **exclude):
|
||||
deserialize = {
|
||||
'model': lambda p: util.model_from_bytes(self.model, p.open('rb').read()),
|
||||
'vocab': lambda p: self.vocab.from_disk(p)
|
||||
}
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
||||
|
||||
class NeuralLabeller(NeuralTagger):
|
||||
|
|
|
@ -631,37 +631,53 @@ cdef class Parser:
|
|||
with self.model[1].use_params(params):
|
||||
yield
|
||||
|
||||
def to_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
dill.dump(self.model, file_)
|
||||
def to_disk(self, path, **exclude):
|
||||
serializers = {
|
||||
'model': lambda p: p.open('wb').write(
|
||||
util.model_to_bytes(self.model)),
|
||||
'vocab': lambda p: self.vocab.to_disk(p),
|
||||
'moves': lambda p: self.moves.to_disk(p, strings=False),
|
||||
'cfg': lambda p: ujson.dumps(p.open('w'), self.cfg)
|
||||
}
|
||||
util.to_disk(path, serializers, exclude)
|
||||
|
||||
def from_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
with (path / 'model.bin').open('wb') as file_:
|
||||
self.model = dill.load(file_)
|
||||
def from_disk(self, path, **exclude):
|
||||
deserializers = {
|
||||
'vocab': lambda p: self.vocab.from_disk(p),
|
||||
'moves': lambda p: self.moves.from_disk(p, strings=False),
|
||||
'cfg': lambda p: self.cfg.update(ujson.load((path/'cfg.json').open())),
|
||||
'model': lambda p: None
|
||||
}
|
||||
util.from_disk(path, deserializers, exclude)
|
||||
if 'model' not in exclude:
|
||||
path = util.ensure_path(path)
|
||||
if self.model is True:
|
||||
self.model = self.Model(**self.cfg)
|
||||
util.model_from_disk(self.model, path / 'model')
|
||||
return self
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
serialize = {
|
||||
serializers = {
|
||||
'model': lambda: util.model_to_bytes(self.model),
|
||||
'vocab': lambda: self.vocab.to_bytes(),
|
||||
'moves': lambda: self.moves.to_bytes(),
|
||||
'moves': lambda: self.moves.to_bytes(vocab=False),
|
||||
'cfg': lambda: ujson.dumps(self.cfg)
|
||||
}
|
||||
return util.to_bytes(serialize, exclude)
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
deserialize = {
|
||||
deserializers = {
|
||||
'vocab': lambda b: self.vocab.from_bytes(b),
|
||||
'moves': lambda b: self.moves.from_bytes(b),
|
||||
'cfg': lambda b: self.cfg.update(ujson.loads(b)),
|
||||
'model': lambda b: None
|
||||
}
|
||||
msg = util.from_bytes(deserialize, exclude)
|
||||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
if 'model' not in exclude:
|
||||
if self.model is True:
|
||||
self.model = self.Model(**msg['cfg'])
|
||||
util.model_from_disk(self.model, msg['model'])
|
||||
print(msg['cfg'])
|
||||
self.model = self.Model(self.moves.n_moves)
|
||||
util.model_from_bytes(self.model, msg['model'])
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,9 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_XDECREF
|
|||
from cymem.cymem cimport Pool
|
||||
from thinc.typedefs cimport weight_t
|
||||
from collections import defaultdict, OrderedDict
|
||||
import ujson
|
||||
|
||||
from .. import util
|
||||
from ..structs cimport TokenC
|
||||
from .stateclass cimport StateClass
|
||||
from ..attrs cimport TAG, HEAD, DEP, ENT_TYPE, ENT_IOB
|
||||
|
@ -153,3 +155,48 @@ cdef class TransitionSystem:
|
|||
assert self.c[self.n_moves].label == label_id
|
||||
self.n_moves += 1
|
||||
return 1
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
actions = list(self.move_names)
|
||||
deserializers = {
|
||||
'actions': lambda p: ujson.dump(p.open('w'), actions),
|
||||
'strings': lambda p: self.strings.to_disk(p)
|
||||
}
|
||||
util.to_disk(path, deserializers, exclude)
|
||||
|
||||
def from_disk(self, path, **exclude):
|
||||
actions = []
|
||||
deserializers = {
|
||||
'strings': lambda p: self.strings.from_disk(p),
|
||||
'actions': lambda p: actions.extend(ujson.load(p.open()))
|
||||
}
|
||||
util.from_disk(path, deserializers, exclude)
|
||||
for move, label in actions:
|
||||
self.add_action(move, label)
|
||||
return self
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
transitions = []
|
||||
for trans in self.c[:self.n_moves]:
|
||||
transitions.append({
|
||||
'clas': trans.clas,
|
||||
'move': trans.move,
|
||||
'label': self.strings[trans.label],
|
||||
'name': self.move_name(trans.move, trans.label)
|
||||
})
|
||||
serializers = {
|
||||
'transitions': lambda: ujson.dumps(transitions),
|
||||
'strings': lambda: self.strings.to_bytes()
|
||||
}
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, **exclude):
|
||||
transitions = []
|
||||
deserializers = {
|
||||
'transitions': lambda b: transitions.extend(ujson.loads(b)),
|
||||
'strings': lambda b: self.strings.from_bytes(b)
|
||||
}
|
||||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
for trans in transitions:
|
||||
self.add_action(trans['move'], trans['label'])
|
||||
return self
|
||||
|
|
34
spacy/tests/parser/test_to_from_bytes_disk.py
Normal file
34
spacy/tests/parser/test_to_from_bytes_disk.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from ...pipeline import NeuralDependencyParser
|
||||
from ...vocab import Vocab
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vocab():
|
||||
return Vocab()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(vocab):
|
||||
parser = NeuralDependencyParser(vocab)
|
||||
parser.add_label('nsubj')
|
||||
parser.model, cfg = parser.Model(parser.moves.n_moves)
|
||||
parser.cfg.update(cfg)
|
||||
return parser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def blank_parser(vocab):
|
||||
parser = NeuralDependencyParser(vocab)
|
||||
return parser
|
||||
|
||||
|
||||
def test_to_from_bytes(parser, blank_parser):
|
||||
assert parser.model is not True
|
||||
assert blank_parser.model is True
|
||||
assert blank_parser.moves.n_moves != parser.moves.n_moves
|
||||
bytes_data = parser.to_bytes()
|
||||
blank_parser.from_bytes(bytes_data)
|
||||
assert blank_parser.model is not True
|
||||
assert blank_parser.moves.n_moves == parser.moves.n_moves
|
|
@ -417,11 +417,11 @@ def to_bytes(getters, exclude):
|
|||
for key, getter in getters.items():
|
||||
if key not in exclude:
|
||||
serialized[key] = getter()
|
||||
return messagepack.dumps(serialized)
|
||||
return msgpack.dumps(serialized)
|
||||
|
||||
|
||||
def from_bytes(bytes_data, setters, exclude):
|
||||
msg = messagepack.loads(bytes_data)
|
||||
msg = msgpack.loads(bytes_data)
|
||||
for key, setter in setters.items():
|
||||
if key not in exclude:
|
||||
setter(msg[key])
|
||||
|
|
Loading…
Reference in New Issue
Block a user