From 19c495f451e3b83f5575743d63c3745a9fd5eaa2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 19 Aug 2017 04:33:03 +0200 Subject: [PATCH] Fix vectors deserialization --- spacy/vectors.pyx | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 59a24dfa9..1b1e8000a 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -20,7 +20,7 @@ cdef class Vectors: '''Store, save and load word vectors.''' cdef public object data cdef readonly StringStore strings - cdef public object index + cdef public object key2row def __init__(self, strings, data_or_width): self.strings = StringStore() @@ -30,9 +30,9 @@ cdef class Vectors: else: data = data_or_width self.data = data - self.index = {} + self.key2row = {} for i, string in enumerate(strings): - self.index[self.strings.add(string)] = i + self.key2row[self.strings.add(string)] = i def __reduce__(self): return (Vectors, (self.strings, self.data)) @@ -40,7 +40,7 @@ cdef class Vectors: def __getitem__(self, key): if isinstance(key, basestring): key = self.strings[key] - i = self.index[key] + i = self.key2row[key] if i is None: raise KeyError(key) else: @@ -49,7 +49,7 @@ cdef class Vectors: def __setitem__(self, key, vector): if isinstance(key, basestring): key = self.strings.add(key) - i = self.index[key] + i = self.key2row[key] self.data[i] = vector def __iter__(self): @@ -71,7 +71,7 @@ cdef class Vectors: def to_disk(self, path, **exclude): def serialize_vectors(p): - write_vectors_to_bin_loc(self.strings, self.key2i, self.data, str(p)) + write_vectors_to_bin_loc(self.strings, self.key2row, self.data, str(p)) serializers = OrderedDict(( ('vec.bin', serialize_vectors), @@ -80,12 +80,13 @@ cdef class Vectors: def from_disk(self, path, **exclude): def deserialize_vectors(p): - self.key2i, self.vectors = load_vectors_from_bin_loc(self.strings, str(p)) + values = load_vectors_from_bin_loc(self.strings, str(p)) + self.key2row, self.data = values serializers = OrderedDict(( - ('vec.bin', deserialize_vectors) + ('vec.bin', deserialize_vectors), )) - return util.to_disk(serializers, exclude) + return util.from_disk(path, serializers, exclude) def to_bytes(self, **exclude): def serialize_weights(): @@ -93,9 +94,9 @@ cdef class Vectors: return self.data.to_bytes() else: return msgpack.dumps(self.data) - + b = msgpack.dumps(self.key2row) serializers = OrderedDict(( - ('key2row', lambda: msgpack.dumps(self.key2i)), + ('key2row', lambda: msgpack.dumps(self.key2row)), ('strings', lambda: self.strings.to_bytes()), ('vectors', serialize_weights) )) @@ -109,7 +110,7 @@ cdef class Vectors: self.data = msgpack.loads(b) deserializers = OrderedDict(( - ('key2row', lambda b: self.key2i.update(msgpack.loads(b))), + ('key2row', lambda b: self.key2row.update(msgpack.loads(b))), ('strings', lambda b: self.strings.from_bytes(b)), ('vectors', deserialize_weights) ))