Fix serialization

This commit is contained in:
Matthew Honnibal 2017-08-19 21:27:35 +02:00
parent 1157294434
commit 6a94648373
2 changed files with 6 additions and 8 deletions

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
from libc.stdint cimport int32_t, uint64_t from libc.stdint cimport int32_t, uint64_t
import numpy import numpy
from collections import OrderedDict from collections import OrderedDict
@ -32,7 +33,7 @@ cdef class Vectors:
self.key2row = {} self.key2row = {}
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64') self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
for string in strings: for string in strings:
self.add_key(string) self.add(string)
def __reduce__(self): def __reduce__(self):
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))
@ -94,7 +95,6 @@ cdef class Vectors:
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
serializers = OrderedDict(( serializers = OrderedDict((
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
('strings.json', self.strings.to_disk),
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)), ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
)) ))
return util.to_disk(path, serializers, exclude) return util.to_disk(path, serializers, exclude)
@ -112,7 +112,6 @@ cdef class Vectors:
serializers = OrderedDict(( serializers = OrderedDict((
('keys', load_keys), ('keys', load_keys),
('vectors', load_vectors), ('vectors', load_vectors),
('strings.json', self.strings.from_disk),
)) ))
util.from_disk(path, serializers, exclude) util.from_disk(path, serializers, exclude)
return self return self
@ -125,7 +124,6 @@ cdef class Vectors:
return msgpack.dumps(self.data) return msgpack.dumps(self.data)
serializers = OrderedDict(( serializers = OrderedDict((
('keys', lambda: msgpack.dumps(self.keys)), ('keys', lambda: msgpack.dumps(self.keys)),
('strings', lambda: self.strings.to_bytes()),
('vectors', serialize_weights) ('vectors', serialize_weights)
)) ))
return util.to_bytes(serializers, exclude) return util.to_bytes(serializers, exclude)
@ -138,13 +136,13 @@ cdef class Vectors:
self.data = msgpack.loads(b) self.data = msgpack.loads(b)
def load_keys(keys): def load_keys(keys):
self.keys.resize((len(keys),))
for i, key in enumerate(keys): for i, key in enumerate(keys):
self.keys[i] = key self.keys[i] = key
self.key2row[key] = i self.key2row[key] = i
deserializers = OrderedDict(( deserializers = OrderedDict((
('keys', lambda b: load_keys(msgpack.loads(b))), ('keys', lambda b: load_keys(msgpack.loads(b))),
('strings', lambda b: self.strings.from_bytes(b)),
('vectors', deserialize_weights) ('vectors', deserialize_weights)
)) ))
util.from_bytes(data, deserializers, exclude) util.from_bytes(data, deserializers, exclude)

View File

@ -303,7 +303,7 @@ cdef class Vocab:
with (path / 'lexemes.bin').open('wb') as file_: with (path / 'lexemes.bin').open('wb') as file_:
file_.write(self.lexemes_to_bytes()) file_.write(self.lexemes_to_bytes())
if self.vectors is not None: if self.vectors is not None:
self.vectors.to_disk(path, exclude='strings.json') self.vectors.to_disk(path)
def from_disk(self, path, **exclude): def from_disk(self, path, **exclude):
"""Loads state from a directory. Modifies the object in place and """Loads state from a directory. Modifies the object in place and
@ -318,7 +318,7 @@ cdef class Vocab:
with (path / 'lexemes.bin').open('rb') as file_: with (path / 'lexemes.bin').open('rb') as file_:
self.lexemes_from_bytes(file_.read()) self.lexemes_from_bytes(file_.read())
if self.vectors is not None: if self.vectors is not None:
self.vectors.from_disk(path, exclude='string.json') self.vectors.from_disk(path, exclude='strings.json')
return self return self
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
@ -331,7 +331,7 @@ cdef class Vocab:
if self.vectors is None: if self.vectors is None:
return None return None
else: else:
return self.vectors.to_bytes(exclude='strings') return self.vectors.to_bytes(exclude='strings.json')
getters = OrderedDict(( getters = OrderedDict((
('strings', lambda: self.strings.to_bytes()), ('strings', lambda: self.strings.to_bytes()),