mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Improve vectors to/from disk
This commit is contained in:
parent
3a3ed43e0c
commit
3d049af563
|
@ -21,6 +21,7 @@ cdef class Vectors:
|
|||
cdef public object data
|
||||
cdef readonly StringStore strings
|
||||
cdef public object key2row
|
||||
cdef public object keys
|
||||
|
||||
def __init__(self, strings, data_or_width):
|
||||
self.strings = StringStore()
|
||||
|
@ -31,8 +32,11 @@ cdef class Vectors:
|
|||
data = data_or_width
|
||||
self.data = data
|
||||
self.key2row = {}
|
||||
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
||||
for i, string in enumerate(strings):
|
||||
self.key2row[self.strings.add(string)] = i
|
||||
key = self.strings.add(string)
|
||||
self.key2row[key] = i
|
||||
self.keys[i] = key
|
||||
|
||||
def __reduce__(self):
|
||||
return (Vectors, (self.strings, self.data))
|
||||
|
@ -70,23 +74,30 @@ cdef class Vectors:
|
|||
raise NotImplementedError
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
def serialize_vectors(p):
|
||||
write_vectors_to_bin_loc(self.strings, self.key2row, self.data, str(p))
|
||||
|
||||
serializers = OrderedDict((
|
||||
('vec.bin', serialize_vectors),
|
||||
('vectors', lambda p: numpy.save(p.open('wb'), self.data)),
|
||||
('strings.json', self.strings.to_disk),
|
||||
('keys', lambda p: numpy.save(p.open('wb'), self.keys)),
|
||||
))
|
||||
return util.to_disk(serializers, exclude)
|
||||
return util.to_disk(path, serializers, exclude)
|
||||
|
||||
def from_disk(self, path, **exclude):
|
||||
def deserialize_vectors(p):
|
||||
values = load_vectors_from_bin_loc(self.strings, str(p))
|
||||
self.key2row, self.data = values
|
||||
def load_keys(path):
|
||||
self.keys = numpy.load(path)
|
||||
for i, key in enumerate(self.keys):
|
||||
self.keys[i] = key
|
||||
self.key2row[key] = i
|
||||
|
||||
def load_vectors(path):
|
||||
self.data = numpy.load(path)
|
||||
|
||||
serializers = OrderedDict((
|
||||
('vec.bin', deserialize_vectors),
|
||||
('keys', load_keys),
|
||||
('vectors', load_vectors),
|
||||
('strings.json', self.strings.from_disk),
|
||||
))
|
||||
return util.from_disk(path, serializers, exclude)
|
||||
util.from_disk(path, serializers, exclude)
|
||||
return self
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
def serialize_weights():
|
||||
|
@ -94,9 +105,8 @@ cdef class Vectors:
|
|||
return self.data.to_bytes()
|
||||
else:
|
||||
return msgpack.dumps(self.data)
|
||||
b = msgpack.dumps(self.key2row)
|
||||
serializers = OrderedDict((
|
||||
('key2row', lambda: msgpack.dumps(self.key2row)),
|
||||
('keys', lambda: msgpack.dumps(self.keys)),
|
||||
('strings', lambda: self.strings.to_bytes()),
|
||||
('vectors', serialize_weights)
|
||||
))
|
||||
|
@ -109,80 +119,15 @@ cdef class Vectors:
|
|||
else:
|
||||
self.data = msgpack.loads(b)
|
||||
|
||||
def load_keys(keys):
|
||||
for i, key in enumerate(keys):
|
||||
self.keys[i] = key
|
||||
self.key2row[key] = i
|
||||
|
||||
deserializers = OrderedDict((
|
||||
('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
|
||||
('keys', lambda b: load_keys(msgpack.loads(b))),
|
||||
('strings', lambda b: self.strings.from_bytes(b)),
|
||||
('vectors', deserialize_weights)
|
||||
))
|
||||
return util.from_bytes(deserializers, exclude)
|
||||
|
||||
|
||||
def write_vectors_to_bin_loc(StringStore strings, dict key2i,
|
||||
np.ndarray vectors, out_loc):
|
||||
|
||||
cdef int32_t vec_len = vectors.shape[1]
|
||||
cdef int32_t word_len
|
||||
cdef bytes word_str
|
||||
cdef char* chars
|
||||
cdef uint64_t key
|
||||
cdef int32_t i
|
||||
cdef float* vec
|
||||
|
||||
cdef CFile out_file = CFile(out_loc, 'wb')
|
||||
keys = [(i, key) for (key, i) in key2i.item()]
|
||||
keys.sort()
|
||||
for i, key in keys:
|
||||
vec = <float*>vectors.data[i * vec_len]
|
||||
word_str = strings[key].encode('utf8')
|
||||
word_len = len(word_str)
|
||||
|
||||
out_file.write_from(&word_len, 1, sizeof(word_len))
|
||||
out_file.write_from(&vec_len, 1, sizeof(vec_len))
|
||||
|
||||
chars = <char*>word_str
|
||||
out_file.write_from(chars, word_len, sizeof(char))
|
||||
out_file.write_from(vec, vec_len, sizeof(float))
|
||||
out_file.close()
|
||||
|
||||
|
||||
def load_vectors_from_bin_loc(StringStore strings, loc):
|
||||
"""
|
||||
Load vectors from the location of a binary file.
|
||||
Arguments:
|
||||
loc (unicode): The path of the binary file to load from.
|
||||
Returns:
|
||||
vec_len (int): The length of the vectors loaded.
|
||||
"""
|
||||
cdef CFile file_ = CFile(loc, b'rb')
|
||||
cdef int32_t word_len
|
||||
cdef int32_t vec_len = 0
|
||||
cdef int32_t prev_vec_len = 0
|
||||
cdef float* vec
|
||||
cdef attr_t string_id
|
||||
cdef bytes py_word
|
||||
cdef vector[float*] vectors
|
||||
cdef int line_num = 0
|
||||
cdef Pool mem = Pool()
|
||||
cdef dict key2i = {}
|
||||
while True:
|
||||
try:
|
||||
file_.read_into(&word_len, sizeof(word_len), 1)
|
||||
except IOError:
|
||||
break
|
||||
file_.read_into(&vec_len, sizeof(vec_len), 1)
|
||||
if prev_vec_len != 0 and vec_len != prev_vec_len:
|
||||
raise Exception("Mismatched vector sizes")
|
||||
if 0 >= vec_len >= MAX_VEC_SIZE:
|
||||
raise Exception("Mismatched vector sizes")
|
||||
|
||||
chars = <char*>file_.alloc_read(mem, word_len, sizeof(char))
|
||||
vec = <float*>file_.alloc_read(mem, vec_len, sizeof(float))
|
||||
|
||||
key = strings.add(chars[:word_len])
|
||||
key2i[key] = vectors.size()
|
||||
vectors.push_back(vec)
|
||||
numpy_vectors = numpy.zeros((vectors.size(), vec_len), dtype='f')
|
||||
for i in range(vectors.size()):
|
||||
for j in range(vec_len):
|
||||
numpy_vectors[i, j] = vectors[i][j]
|
||||
return key2i, numpy_vectors
|
||||
util.from_bytes(deserializers, exclude)
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue
Block a user