mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Fix serialization
This commit is contained in:
parent
1157294434
commit
6a94648373
|
@ -1,3 +1,4 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
from libc.stdint cimport int32_t, uint64_t
|
from libc.stdint cimport int32_t, uint64_t
|
||||||
import numpy
|
import numpy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -32,7 +33,7 @@ cdef class Vectors:
|
||||||
self.key2row = {}
|
self.key2row = {}
|
||||||
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
||||||
for string in strings:
|
for string in strings:
|
||||||
self.add_key(string)
|
self.add(string)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Vectors, (self.strings, self.data))
|
return (Vectors, (self.strings, self.data))
|
||||||
|
@ -94,7 +95,6 @@ cdef class Vectors:
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
|
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
|
||||||
('strings.json', self.strings.to_disk),
|
|
||||||
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
|
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
|
||||||
))
|
))
|
||||||
return util.to_disk(path, serializers, exclude)
|
return util.to_disk(path, serializers, exclude)
|
||||||
|
@ -112,7 +112,6 @@ cdef class Vectors:
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('keys', load_keys),
|
('keys', load_keys),
|
||||||
('vectors', load_vectors),
|
('vectors', load_vectors),
|
||||||
('strings.json', self.strings.from_disk),
|
|
||||||
))
|
))
|
||||||
util.from_disk(path, serializers, exclude)
|
util.from_disk(path, serializers, exclude)
|
||||||
return self
|
return self
|
||||||
|
@ -125,7 +124,6 @@ cdef class Vectors:
|
||||||
return msgpack.dumps(self.data)
|
return msgpack.dumps(self.data)
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('keys', lambda: msgpack.dumps(self.keys)),
|
('keys', lambda: msgpack.dumps(self.keys)),
|
||||||
('strings', lambda: self.strings.to_bytes()),
|
|
||||||
('vectors', serialize_weights)
|
('vectors', serialize_weights)
|
||||||
))
|
))
|
||||||
return util.to_bytes(serializers, exclude)
|
return util.to_bytes(serializers, exclude)
|
||||||
|
@ -138,13 +136,13 @@ cdef class Vectors:
|
||||||
self.data = msgpack.loads(b)
|
self.data = msgpack.loads(b)
|
||||||
|
|
||||||
def load_keys(keys):
|
def load_keys(keys):
|
||||||
|
self.keys.resize((len(keys),))
|
||||||
for i, key in enumerate(keys):
|
for i, key in enumerate(keys):
|
||||||
self.keys[i] = key
|
self.keys[i] = key
|
||||||
self.key2row[key] = i
|
self.key2row[key] = i
|
||||||
|
|
||||||
deserializers = OrderedDict((
|
deserializers = OrderedDict((
|
||||||
('keys', lambda b: load_keys(msgpack.loads(b))),
|
('keys', lambda b: load_keys(msgpack.loads(b))),
|
||||||
('strings', lambda b: self.strings.from_bytes(b)),
|
|
||||||
('vectors', deserialize_weights)
|
('vectors', deserialize_weights)
|
||||||
))
|
))
|
||||||
util.from_bytes(data, deserializers, exclude)
|
util.from_bytes(data, deserializers, exclude)
|
||||||
|
|
|
@ -303,7 +303,7 @@ cdef class Vocab:
|
||||||
with (path / 'lexemes.bin').open('wb') as file_:
|
with (path / 'lexemes.bin').open('wb') as file_:
|
||||||
file_.write(self.lexemes_to_bytes())
|
file_.write(self.lexemes_to_bytes())
|
||||||
if self.vectors is not None:
|
if self.vectors is not None:
|
||||||
self.vectors.to_disk(path, exclude='strings.json')
|
self.vectors.to_disk(path)
|
||||||
|
|
||||||
def from_disk(self, path, **exclude):
|
def from_disk(self, path, **exclude):
|
||||||
"""Loads state from a directory. Modifies the object in place and
|
"""Loads state from a directory. Modifies the object in place and
|
||||||
|
@ -318,7 +318,7 @@ cdef class Vocab:
|
||||||
with (path / 'lexemes.bin').open('rb') as file_:
|
with (path / 'lexemes.bin').open('rb') as file_:
|
||||||
self.lexemes_from_bytes(file_.read())
|
self.lexemes_from_bytes(file_.read())
|
||||||
if self.vectors is not None:
|
if self.vectors is not None:
|
||||||
self.vectors.from_disk(path, exclude='string.json')
|
self.vectors.from_disk(path, exclude='strings.json')
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
|
@ -331,7 +331,7 @@ cdef class Vocab:
|
||||||
if self.vectors is None:
|
if self.vectors is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return self.vectors.to_bytes(exclude='strings')
|
return self.vectors.to_bytes(exclude='strings.json')
|
||||||
|
|
||||||
getters = OrderedDict((
|
getters = OrderedDict((
|
||||||
('strings', lambda: self.strings.to_bytes()),
|
('strings', lambda: self.strings.to_bytes()),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user