mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix vectors deserialization
This commit is contained in:
		
							parent
							
								
									42d47c1e5c
								
							
						
					
					
						commit
						19c495f451
					
				| 
						 | 
				
			
			@ -20,7 +20,7 @@ cdef class Vectors:
 | 
			
		|||
    '''Store, save and load word vectors.'''
 | 
			
		||||
    cdef public object data
 | 
			
		||||
    cdef readonly StringStore strings
 | 
			
		||||
    cdef public object index
 | 
			
		||||
    cdef public object key2row
 | 
			
		||||
 | 
			
		||||
    def __init__(self, strings, data_or_width):
 | 
			
		||||
        self.strings = StringStore()
 | 
			
		||||
| 
						 | 
				
			
			@ -30,9 +30,9 @@ cdef class Vectors:
 | 
			
		|||
        else:
 | 
			
		||||
            data = data_or_width
 | 
			
		||||
        self.data = data
 | 
			
		||||
        self.index = {}
 | 
			
		||||
        self.key2row = {}
 | 
			
		||||
        for i, string in enumerate(strings):
 | 
			
		||||
            self.index[self.strings.add(string)] = i
 | 
			
		||||
            self.key2row[self.strings.add(string)] = i
 | 
			
		||||
 | 
			
		||||
    def __reduce__(self):
 | 
			
		||||
        return (Vectors, (self.strings, self.data))
 | 
			
		||||
| 
						 | 
				
			
			@ -40,7 +40,7 @@ cdef class Vectors:
 | 
			
		|||
    def __getitem__(self, key):
 | 
			
		||||
        if isinstance(key, basestring):
 | 
			
		||||
            key = self.strings[key]
 | 
			
		||||
        i = self.index[key]
 | 
			
		||||
        i = self.key2row[key]
 | 
			
		||||
        if i is None:
 | 
			
		||||
            raise KeyError(key)
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -49,7 +49,7 @@ cdef class Vectors:
 | 
			
		|||
    def __setitem__(self, key, vector):
 | 
			
		||||
        if isinstance(key, basestring):
 | 
			
		||||
            key = self.strings.add(key)
 | 
			
		||||
        i = self.index[key]
 | 
			
		||||
        i = self.key2row[key]
 | 
			
		||||
        self.data[i] = vector
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -71,7 +71,7 @@ cdef class Vectors:
 | 
			
		|||
 | 
			
		||||
    def to_disk(self, path, **exclude):
 | 
			
		||||
        def serialize_vectors(p):
 | 
			
		||||
            write_vectors_to_bin_loc(self.strings, self.key2i, self.data, str(p))
 | 
			
		||||
            write_vectors_to_bin_loc(self.strings, self.key2row, self.data, str(p))
 | 
			
		||||
 | 
			
		||||
        serializers = OrderedDict((
 | 
			
		||||
            ('vec.bin', serialize_vectors),
 | 
			
		||||
| 
						 | 
				
			
			@ -80,12 +80,13 @@ cdef class Vectors:
 | 
			
		|||
 | 
			
		||||
    def from_disk(self, path, **exclude):
 | 
			
		||||
        def deserialize_vectors(p):
 | 
			
		||||
            self.key2i, self.vectors = load_vectors_from_bin_loc(self.strings, str(p))
 | 
			
		||||
            values = load_vectors_from_bin_loc(self.strings, str(p))
 | 
			
		||||
            self.key2row, self.data = values
 | 
			
		||||
 | 
			
		||||
        serializers = OrderedDict((
 | 
			
		||||
            ('vec.bin', deserialize_vectors)
 | 
			
		||||
            ('vec.bin', deserialize_vectors),
 | 
			
		||||
        ))
 | 
			
		||||
        return util.to_disk(serializers, exclude)
 | 
			
		||||
        return util.from_disk(path, serializers, exclude)
 | 
			
		||||
 | 
			
		||||
    def to_bytes(self, **exclude):
 | 
			
		||||
        def serialize_weights():
 | 
			
		||||
| 
						 | 
				
			
			@ -93,9 +94,9 @@ cdef class Vectors:
 | 
			
		|||
                return self.data.to_bytes()
 | 
			
		||||
            else:
 | 
			
		||||
                return msgpack.dumps(self.data)
 | 
			
		||||
 | 
			
		||||
        b = msgpack.dumps(self.key2row)
 | 
			
		||||
        serializers = OrderedDict((
 | 
			
		||||
            ('key2row', lambda: msgpack.dumps(self.key2i)),
 | 
			
		||||
            ('key2row', lambda: msgpack.dumps(self.key2row)),
 | 
			
		||||
            ('strings', lambda: self.strings.to_bytes()),
 | 
			
		||||
            ('vectors', serialize_weights)
 | 
			
		||||
        ))
 | 
			
		||||
| 
						 | 
				
			
			@ -109,7 +110,7 @@ cdef class Vectors:
 | 
			
		|||
                self.data = msgpack.loads(b)
 | 
			
		||||
 | 
			
		||||
        deserializers = OrderedDict((
 | 
			
		||||
            ('key2row', lambda b: self.key2i.update(msgpack.loads(b))),
 | 
			
		||||
            ('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
 | 
			
		||||
            ('strings', lambda b: self.strings.from_bytes(b)),
 | 
			
		||||
            ('vectors', deserialize_weights)
 | 
			
		||||
        ))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user