mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Fix to/from disk methods
This commit is contained in:
parent
5c30466c95
commit
33e5ec737f
|
@ -96,6 +96,13 @@ class BaseDefaults(object):
|
|||
|
||||
factories = {
|
||||
'make_doc': create_tokenizer,
|
||||
'tensorizer': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)],
|
||||
'tagger': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
|
||||
'parser': lambda nlp, **cfg: [
|
||||
NeuralDependencyParser(nlp.vocab, **cfg),
|
||||
nonproj.deprojectivize],
|
||||
'ner': lambda nlp, **cfg: [NeuralEntityRecognizer(nlp.vocab, **cfg)],
|
||||
# Temporary compatibility -- delete after pivot
|
||||
'token_vectors': lambda nlp, **cfg: [TokenVectorEncoder(nlp.vocab, **cfg)],
|
||||
'tags': lambda nlp, **cfg: [NeuralTagger(nlp.vocab, **cfg)],
|
||||
'dependencies': lambda nlp, **cfg: [
|
||||
|
@ -358,37 +365,35 @@ class Language(object):
|
|||
for doc in docs:
|
||||
yield doc
|
||||
|
||||
def to_disk(self, path, disable=[]):
|
||||
def to_disk(self, path, disable=tuple()):
|
||||
"""Save the current state to a directory. If a model is loaded, this
|
||||
will include the model.
|
||||
|
||||
path (unicode or Path): A path to a directory, which will be created if
|
||||
it doesn't exist. Paths may be either strings or `Path`-like objects.
|
||||
disable (list): Nameds of pipeline components to disable and prevent
|
||||
disable (list): Names of pipeline components to disable and prevent
|
||||
from being saved.
|
||||
|
||||
EXAMPLE:
|
||||
>>> nlp.to_disk('/path/to/models')
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
with path.open('wb') as file_:
|
||||
file_.write(self.to_bytes(disable))
|
||||
#serializers = {
|
||||
# 'vocab': lambda p: self.vocab.to_disk(p),
|
||||
# 'tokenizer': lambda p: self.tokenizer.to_disk(p, vocab=False),
|
||||
# 'meta.json': lambda p: ujson.dump(p.open('w'), self.meta)
|
||||
#}
|
||||
#for proc in self.pipeline:
|
||||
# if not hasattr(proc, 'name'):
|
||||
# continue
|
||||
# if proc.name in disable:
|
||||
# continue
|
||||
# if not hasattr(proc, 'to_disk'):
|
||||
# continue
|
||||
# serializers[proc.name] = lambda p: proc.to_disk(p, vocab=False)
|
||||
#util.to_disk(serializers, path)
|
||||
serializers = OrderedDict((
|
||||
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||
('tokenizer', lambda p: self.tokenizer.to_disk(p, vocab=False)),
|
||||
('meta.json', lambda p: p.open('w').write(json_dumps(self.meta)))
|
||||
))
|
||||
for proc in self.pipeline:
|
||||
if not hasattr(proc, 'name'):
|
||||
continue
|
||||
if proc.name in disable:
|
||||
continue
|
||||
if not hasattr(proc, 'to_disk'):
|
||||
continue
|
||||
serializers[proc.name] = lambda p, proc=proc: proc.to_disk(p, vocab=False)
|
||||
util.to_disk(path, serializers, {p: False for p in disable})
|
||||
|
||||
def from_disk(self, path, disable=[]):
|
||||
def from_disk(self, path, disable=tuple()):
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
returns it. If the saved `Language` object contains a model, the
|
||||
model will be loaded.
|
||||
|
@ -403,24 +408,21 @@ class Language(object):
|
|||
>>> nlp = Language().from_disk('/path/to/models')
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
with path.open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
return self.from_bytes(bytes_data, disable)
|
||||
#deserializers = {
|
||||
# 'vocab': lambda p: self.vocab.from_disk(p),
|
||||
# 'tokenizer': lambda p: self.tokenizer.from_disk(p, vocab=False),
|
||||
# 'meta.json': lambda p: ujson.dump(p.open('w'), self.meta)
|
||||
#}
|
||||
#for proc in self.pipeline:
|
||||
# if not hasattr(proc, 'name'):
|
||||
# continue
|
||||
# if proc.name in disable:
|
||||
# continue
|
||||
# if not hasattr(proc, 'to_disk'):
|
||||
# continue
|
||||
# deserializers[proc.name] = lambda p: proc.from_disk(p, vocab=False)
|
||||
#util.from_disk(deserializers, path)
|
||||
#return self
|
||||
deserializers = OrderedDict((
|
||||
('vocab', lambda p: self.vocab.from_disk(p)),
|
||||
('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
|
||||
('meta.json', lambda p: p.open('w').write(json_dumps(self.meta)))
|
||||
))
|
||||
for proc in self.pipeline:
|
||||
if not hasattr(proc, 'name'):
|
||||
continue
|
||||
if proc.name in disable:
|
||||
continue
|
||||
if not hasattr(proc, 'to_disk'):
|
||||
continue
|
||||
deserializers[proc.name] = lambda p, proc=proc: proc.from_disk(p, vocab=False)
|
||||
util.from_disk(path, deserializers, {p: False for p in disable})
|
||||
return self
|
||||
|
||||
def to_bytes(self, disable=[]):
|
||||
"""Serialize the current state to a binary string.
|
||||
|
|
|
@ -41,7 +41,7 @@ from .parts_of_speech import X
|
|||
|
||||
class TokenVectorEncoder(object):
|
||||
"""Assign position-sensitive vectors to tokens, using a CNN or RNN."""
|
||||
name = 'tok2vec'
|
||||
name = 'tensorizer'
|
||||
|
||||
@classmethod
|
||||
def Model(cls, width=128, embed_size=7500, **cfg):
|
||||
|
@ -176,17 +176,19 @@ class TokenVectorEncoder(object):
|
|||
return self
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
serialize = {
|
||||
'model': lambda p: p.open('w').write(util.model_to_bytes(self.model)),
|
||||
'vocab': lambda p: self.vocab.to_disk(p)
|
||||
}
|
||||
serialize = OrderedDict((
|
||||
('model', lambda p: p.open('wb').write(util.model_to_bytes(self.model))),
|
||||
('vocab', lambda p: self.vocab.to_disk(p))
|
||||
))
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, **exclude):
|
||||
deserialize = {
|
||||
'model': lambda p: util.model_from_bytes(self.model, p.open('rb').read()),
|
||||
'vocab': lambda p: self.vocab.from_disk(p)
|
||||
}
|
||||
if self.model is True:
|
||||
self.model = self.Model()
|
||||
deserialize = OrderedDict((
|
||||
('model', lambda p: util.model_from_bytes(self.model, p.open('rb').read())),
|
||||
('vocab', lambda p: self.vocab.from_disk(p))
|
||||
))
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
||||
|
@ -315,7 +317,7 @@ class NeuralTagger(object):
|
|||
|
||||
def to_disk(self, path, **exclude):
|
||||
serialize = {
|
||||
'model': lambda p: p.open('w').write(util.model_to_bytes(self.model)),
|
||||
'model': lambda p: p.open('wb').write(util.model_to_bytes(self.model)),
|
||||
'vocab': lambda p: self.vocab.to_disk(p)
|
||||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
@ -420,7 +422,7 @@ cdef class NeuralDependencyParser(NeuralParser):
|
|||
|
||||
|
||||
cdef class NeuralEntityRecognizer(NeuralParser):
|
||||
name = 'entity'
|
||||
name = 'ner'
|
||||
TransitionSystem = BiluoPushDown
|
||||
|
||||
nr_feature = 6
|
||||
|
|
|
@ -44,6 +44,7 @@ from .. import util
|
|||
from ..util import get_async, get_cuda_stream
|
||||
from .._ml import zero_init, PrecomputableAffine, PrecomputableMaxouts
|
||||
from .._ml import Tok2Vec, doc2feats, rebatch
|
||||
from ..compat import json_dumps
|
||||
|
||||
from . import _parse_features
|
||||
from ._parse_features cimport CONTEXT_SIZE
|
||||
|
@ -633,11 +634,13 @@ cdef class Parser:
|
|||
|
||||
def to_disk(self, path, **exclude):
|
||||
serializers = {
|
||||
'model': lambda p: p.open('wb').write(
|
||||
util.model_to_bytes(self.model)),
|
||||
'lower_model': lambda p: p.open('wb').write(
|
||||
util.model_to_bytes(self.model[0])),
|
||||
'upper_model': lambda p: p.open('wb').write(
|
||||
util.model_to_bytes(self.model[1])),
|
||||
'vocab': lambda p: self.vocab.to_disk(p),
|
||||
'moves': lambda p: self.moves.to_disk(p, strings=False),
|
||||
'cfg': lambda p: ujson.dumps(p.open('w'), self.cfg)
|
||||
'cfg': lambda p: p.open('w').write(json_dumps(self.cfg))
|
||||
}
|
||||
util.to_disk(path, serializers, exclude)
|
||||
|
||||
|
@ -645,7 +648,7 @@ cdef class Parser:
|
|||
deserializers = {
|
||||
'vocab': lambda p: self.vocab.from_disk(p),
|
||||
'moves': lambda p: self.moves.from_disk(p, strings=False),
|
||||
'cfg': lambda p: self.cfg.update(ujson.load((path/'cfg.json').open())),
|
||||
'cfg': lambda p: self.cfg.update(ujson.load(p.open())),
|
||||
'model': lambda p: None
|
||||
}
|
||||
util.from_disk(path, deserializers, exclude)
|
||||
|
@ -653,7 +656,14 @@ cdef class Parser:
|
|||
path = util.ensure_path(path)
|
||||
if self.model is True:
|
||||
self.model, cfg = self.Model(**self.cfg)
|
||||
util.model_from_disk(self.model, path / 'model')
|
||||
else:
|
||||
cfg = {}
|
||||
with (path / 'lower_model').open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
util.model_from_bytes(self.model[0], bytes_data)
|
||||
with (path / 'upper_model').open('rb') as file_:
|
||||
bytes_data = file_.read()
|
||||
util.model_from_bytes(self.model[1], bytes_data)
|
||||
self.cfg.update(cfg)
|
||||
return self
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import random
|
|||
import numpy
|
||||
import io
|
||||
import dill
|
||||
from collections import OrderedDict
|
||||
|
||||
import msgpack
|
||||
import msgpack_numpy
|
||||
|
@ -408,7 +409,7 @@ def get_raw_input(description, default=False):
|
|||
|
||||
|
||||
def to_bytes(getters, exclude):
|
||||
serialized = {}
|
||||
serialized = OrderedDict()
|
||||
for key, getter in getters.items():
|
||||
if key not in exclude:
|
||||
serialized[key] = getter()
|
||||
|
@ -423,6 +424,24 @@ def from_bytes(bytes_data, setters, exclude):
|
|||
return msg
|
||||
|
||||
|
||||
def to_disk(path, writers, exclude):
|
||||
path = ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir()
|
||||
for key, writer in writers.items():
|
||||
if key not in exclude:
|
||||
writer(path / key)
|
||||
return path
|
||||
|
||||
|
||||
def from_disk(path, readers, exclude):
|
||||
path = ensure_path(path)
|
||||
for key, reader in readers.items():
|
||||
if key not in exclude:
|
||||
reader(path / key)
|
||||
return path
|
||||
|
||||
|
||||
# This stuff really belongs in thinc -- but I expect
|
||||
# to refactor how all this works in thinc anyway.
|
||||
# What a mess!
|
||||
|
|
Loading…
Reference in New Issue
Block a user