Fix vectors deserialization

This commit is contained in:
Matthew Honnibal 2017-08-19 04:33:03 +02:00
parent 42d47c1e5c
commit 19c495f451

View File

@ -20,7 +20,7 @@ cdef class Vectors:
'''Store, save and load word vectors.''' '''Store, save and load word vectors.'''
cdef public object data cdef public object data
cdef readonly StringStore strings cdef readonly StringStore strings
cdef public object index cdef public object key2row
def __init__(self, strings, data_or_width): def __init__(self, strings, data_or_width):
self.strings = StringStore() self.strings = StringStore()
@ -30,9 +30,9 @@ cdef class Vectors:
else: else:
data = data_or_width data = data_or_width
self.data = data self.data = data
self.index = {} self.key2row = {}
for i, string in enumerate(strings): for i, string in enumerate(strings):
self.index[self.strings.add(string)] = i self.key2row[self.strings.add(string)] = i
def __reduce__(self): def __reduce__(self):
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))
@ -40,7 +40,7 @@ cdef class Vectors:
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, basestring): if isinstance(key, basestring):
key = self.strings[key] key = self.strings[key]
i = self.index[key] i = self.key2row[key]
if i is None: if i is None:
raise KeyError(key) raise KeyError(key)
else: else:
@ -49,7 +49,7 @@ cdef class Vectors:
def __setitem__(self, key, vector): def __setitem__(self, key, vector):
if isinstance(key, basestring): if isinstance(key, basestring):
key = self.strings.add(key) key = self.strings.add(key)
i = self.index[key] i = self.key2row[key]
self.data[i] = vector self.data[i] = vector
def __iter__(self): def __iter__(self):
@ -71,7 +71,7 @@ cdef class Vectors:
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
def serialize_vectors(p): def serialize_vectors(p):
write_vectors_to_bin_loc(self.strings, self.key2i, self.data, str(p)) write_vectors_to_bin_loc(self.strings, self.key2row, self.data, str(p))
serializers = OrderedDict(( serializers = OrderedDict((
('vec.bin', serialize_vectors), ('vec.bin', serialize_vectors),
@ -80,12 +80,13 @@ cdef class Vectors:
def from_disk(self, path, **exclude): def from_disk(self, path, **exclude):
def deserialize_vectors(p): def deserialize_vectors(p):
self.key2i, self.vectors = load_vectors_from_bin_loc(self.strings, str(p)) values = load_vectors_from_bin_loc(self.strings, str(p))
self.key2row, self.data = values
serializers = OrderedDict(( serializers = OrderedDict((
('vec.bin', deserialize_vectors) ('vec.bin', deserialize_vectors),
)) ))
return util.to_disk(serializers, exclude) return util.from_disk(path, serializers, exclude)
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
def serialize_weights(): def serialize_weights():
@ -93,9 +94,9 @@ cdef class Vectors:
return self.data.to_bytes() return self.data.to_bytes()
else: else:
return msgpack.dumps(self.data) return msgpack.dumps(self.data)
b = msgpack.dumps(self.key2row)
serializers = OrderedDict(( serializers = OrderedDict((
('key2row', lambda: msgpack.dumps(self.key2i)), ('key2row', lambda: msgpack.dumps(self.key2row)),
('strings', lambda: self.strings.to_bytes()), ('strings', lambda: self.strings.to_bytes()),
('vectors', serialize_weights) ('vectors', serialize_weights)
)) ))
@ -109,7 +110,7 @@ cdef class Vectors:
self.data = msgpack.loads(b) self.data = msgpack.loads(b)
deserializers = OrderedDict(( deserializers = OrderedDict((
('key2row', lambda b: self.key2i.update(msgpack.loads(b))), ('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
('strings', lambda b: self.strings.from_bytes(b)), ('strings', lambda b: self.strings.from_bytes(b)),
('vectors', deserialize_weights) ('vectors', deserialize_weights)
)) ))