mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Fix serialization
This commit is contained in:
		
							parent
							
								
									1157294434
								
							
						
					
					
						commit
						6a94648373
					
				|  | @ -1,3 +1,4 @@ | ||||||
|  | from __future__ import unicode_literals | ||||||
| from libc.stdint cimport int32_t, uint64_t | from libc.stdint cimport int32_t, uint64_t | ||||||
| import numpy | import numpy | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
|  | @ -32,7 +33,7 @@ cdef class Vectors: | ||||||
|         self.key2row = {} |         self.key2row = {} | ||||||
|         self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')  |         self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')  | ||||||
|         for string in strings: |         for string in strings: | ||||||
|             self.add_key(string) |             self.add(string) | ||||||
| 
 | 
 | ||||||
|     def __reduce__(self): |     def __reduce__(self): | ||||||
|         return (Vectors, (self.strings, self.data)) |         return (Vectors, (self.strings, self.data)) | ||||||
|  | @ -94,7 +95,6 @@ cdef class Vectors: | ||||||
|     def to_disk(self, path, **exclude): |     def to_disk(self, path, **exclude): | ||||||
|         serializers = OrderedDict(( |         serializers = OrderedDict(( | ||||||
|             ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), |             ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), | ||||||
|             ('strings.json', self.strings.to_disk), |  | ||||||
|             ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)), |             ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)), | ||||||
|         )) |         )) | ||||||
|         return util.to_disk(path, serializers, exclude) |         return util.to_disk(path, serializers, exclude) | ||||||
|  | @ -112,7 +112,6 @@ cdef class Vectors: | ||||||
|         serializers = OrderedDict(( |         serializers = OrderedDict(( | ||||||
|             ('keys', load_keys), |             ('keys', load_keys), | ||||||
|             ('vectors', load_vectors), |             ('vectors', load_vectors), | ||||||
|             ('strings.json', self.strings.from_disk), |  | ||||||
|         )) |         )) | ||||||
|         util.from_disk(path, serializers, exclude) |         util.from_disk(path, serializers, exclude) | ||||||
|         return self |         return self | ||||||
|  | @ -125,7 +124,6 @@ cdef class Vectors: | ||||||
|                 return msgpack.dumps(self.data) |                 return msgpack.dumps(self.data) | ||||||
|         serializers = OrderedDict(( |         serializers = OrderedDict(( | ||||||
|             ('keys', lambda: msgpack.dumps(self.keys)), |             ('keys', lambda: msgpack.dumps(self.keys)), | ||||||
|             ('strings', lambda: self.strings.to_bytes()), |  | ||||||
|             ('vectors', serialize_weights) |             ('vectors', serialize_weights) | ||||||
|         )) |         )) | ||||||
|         return util.to_bytes(serializers, exclude) |         return util.to_bytes(serializers, exclude) | ||||||
|  | @ -138,13 +136,13 @@ cdef class Vectors: | ||||||
|                 self.data = msgpack.loads(b) |                 self.data = msgpack.loads(b) | ||||||
| 
 | 
 | ||||||
|         def load_keys(keys): |         def load_keys(keys): | ||||||
|  |             self.keys.resize((len(keys),)) | ||||||
|             for i, key in enumerate(keys): |             for i, key in enumerate(keys): | ||||||
|                 self.keys[i] = key |                 self.keys[i] = key | ||||||
|                 self.key2row[key] = i |                 self.key2row[key] = i | ||||||
| 
 | 
 | ||||||
|         deserializers = OrderedDict(( |         deserializers = OrderedDict(( | ||||||
|             ('keys', lambda b: load_keys(msgpack.loads(b))), |             ('keys', lambda b: load_keys(msgpack.loads(b))), | ||||||
|             ('strings', lambda b: self.strings.from_bytes(b)), |  | ||||||
|             ('vectors', deserialize_weights) |             ('vectors', deserialize_weights) | ||||||
|         )) |         )) | ||||||
|         util.from_bytes(data, deserializers, exclude) |         util.from_bytes(data, deserializers, exclude) | ||||||
|  |  | ||||||
|  | @ -303,7 +303,7 @@ cdef class Vocab: | ||||||
|         with (path / 'lexemes.bin').open('wb') as file_: |         with (path / 'lexemes.bin').open('wb') as file_: | ||||||
|             file_.write(self.lexemes_to_bytes()) |             file_.write(self.lexemes_to_bytes()) | ||||||
|         if self.vectors is not None: |         if self.vectors is not None: | ||||||
|             self.vectors.to_disk(path, exclude='strings.json') |             self.vectors.to_disk(path) | ||||||
| 
 | 
 | ||||||
|     def from_disk(self, path, **exclude): |     def from_disk(self, path, **exclude): | ||||||
|         """Loads state from a directory. Modifies the object in place and |         """Loads state from a directory. Modifies the object in place and | ||||||
|  | @ -318,7 +318,7 @@ cdef class Vocab: | ||||||
|         with (path / 'lexemes.bin').open('rb') as file_: |         with (path / 'lexemes.bin').open('rb') as file_: | ||||||
|             self.lexemes_from_bytes(file_.read()) |             self.lexemes_from_bytes(file_.read()) | ||||||
|         if self.vectors is not None: |         if self.vectors is not None: | ||||||
|             self.vectors.from_disk(path, exclude='string.json') |             self.vectors.from_disk(path, exclude='strings.json') | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|     def to_bytes(self, **exclude): |     def to_bytes(self, **exclude): | ||||||
|  | @ -331,7 +331,7 @@ cdef class Vocab: | ||||||
|             if self.vectors is None: |             if self.vectors is None: | ||||||
|                 return None |                 return None | ||||||
|             else: |             else: | ||||||
|                 return self.vectors.to_bytes(exclude='strings') |                 return self.vectors.to_bytes(exclude='strings.json') | ||||||
|   |   | ||||||
|         getters = OrderedDict(( |         getters = OrderedDict(( | ||||||
|             ('strings', lambda: self.strings.to_bytes()), |             ('strings', lambda: self.strings.to_bytes()), | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user