From 6a9464837381f2aad57e57b319acf27509174c79 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 19 Aug 2017 21:27:35 +0200 Subject: [PATCH] Fix serialization --- spacy/vectors.pyx | 8 +++----- spacy/vocab.pyx | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index b6ddf0818..ae45f99b3 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from libc.stdint cimport int32_t, uint64_t import numpy from collections import OrderedDict @@ -32,7 +33,7 @@ cdef class Vectors: self.key2row = {} self.keys = np.ndarray((self.data.shape[0],), dtype='uint64') for string in strings: - self.add_key(string) + self.add(string) def __reduce__(self): return (Vectors, (self.strings, self.data)) @@ -94,7 +95,6 @@ cdef class Vectors: def to_disk(self, path, **exclude): serializers = OrderedDict(( ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), - ('strings.json', self.strings.to_disk), ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)), )) return util.to_disk(path, serializers, exclude) @@ -112,7 +112,6 @@ cdef class Vectors: serializers = OrderedDict(( ('keys', load_keys), ('vectors', load_vectors), - ('strings.json', self.strings.from_disk), )) util.from_disk(path, serializers, exclude) return self @@ -125,7 +124,6 @@ cdef class Vectors: return msgpack.dumps(self.data) serializers = OrderedDict(( ('keys', lambda: msgpack.dumps(self.keys)), - ('strings', lambda: self.strings.to_bytes()), ('vectors', serialize_weights) )) return util.to_bytes(serializers, exclude) @@ -138,13 +136,13 @@ cdef class Vectors: self.data = msgpack.loads(b) def load_keys(keys): + self.keys.resize((len(keys),)) for i, key in enumerate(keys): self.keys[i] = key self.key2row[key] = i deserializers = OrderedDict(( ('keys', lambda b: load_keys(msgpack.loads(b))), - ('strings', lambda b: self.strings.from_bytes(b)), ('vectors', deserialize_weights) )) util.from_bytes(data, deserializers, exclude) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 1c992b56c..dc141552d 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -303,7 +303,7 @@ cdef class Vocab: with (path / 'lexemes.bin').open('wb') as file_: file_.write(self.lexemes_to_bytes()) if self.vectors is not None: - self.vectors.to_disk(path, exclude='strings.json') + self.vectors.to_disk(path) def from_disk(self, path, **exclude): """Loads state from a directory. Modifies the object in place and @@ -318,7 +318,7 @@ cdef class Vocab: with (path / 'lexemes.bin').open('rb') as file_: self.lexemes_from_bytes(file_.read()) if self.vectors is not None: - self.vectors.from_disk(path, exclude='string.json') + self.vectors.from_disk(path, exclude='strings.json') return self def to_bytes(self, **exclude): @@ -331,7 +331,7 @@ cdef class Vocab: if self.vectors is None: return None else: - return self.vectors.to_bytes(exclude='strings') + return self.vectors.to_bytes(exclude='strings.json') getters = OrderedDict(( ('strings', lambda: self.strings.to_bytes()),