mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Work on vectors loading
This commit is contained in:
parent
52c6302223
commit
ed4fb991dc
|
@ -1,18 +1,26 @@
|
||||||
|
from libc.stdint cimport int32_t, uint64_t
|
||||||
import numpy
|
import numpy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import msgpack
|
import msgpack
|
||||||
import msgpack_numpy
|
import msgpack_numpy
|
||||||
msgpack_numpy.patch()
|
msgpack_numpy.patch()
|
||||||
|
from cymem.cymem cimport Pool
|
||||||
|
cimport numpy as np
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
|
from .typedefs cimport attr_t
|
||||||
from .strings cimport StringStore
|
from .strings cimport StringStore
|
||||||
from . import util
|
from . import util
|
||||||
|
from ._cfile cimport CFile
|
||||||
|
|
||||||
|
MAX_VEC_SIZE = 10000
|
||||||
|
|
||||||
|
|
||||||
cdef class Vectors:
|
cdef class Vectors:
|
||||||
'''Store, save and load word vectors.'''
|
'''Store, save and load word vectors.'''
|
||||||
cdef public object data
|
cdef public object data
|
||||||
cdef readonly StringStore strings
|
cdef readonly StringStore strings
|
||||||
cdef public object key2i
|
cdef public object index
|
||||||
|
|
||||||
def __init__(self, strings, data_or_width):
|
def __init__(self, strings, data_or_width):
|
||||||
self.strings = StringStore()
|
self.strings = StringStore()
|
||||||
|
@ -22,9 +30,9 @@ cdef class Vectors:
|
||||||
else:
|
else:
|
||||||
data = data_or_width
|
data = data_or_width
|
||||||
self.data = data
|
self.data = data
|
||||||
self.key2i = {}
|
self.index = {}
|
||||||
for i, string in enumerate(strings):
|
for i, string in enumerate(strings):
|
||||||
self.key2i[self.strings.add(string)] = i
|
self.index[self.strings.add(string)] = i
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Vectors, (self.strings, self.data))
|
return (Vectors, (self.strings, self.data))
|
||||||
|
@ -32,7 +40,7 @@ cdef class Vectors:
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, basestring):
|
if isinstance(key, basestring):
|
||||||
key = self.strings[key]
|
key = self.strings[key]
|
||||||
i = self.key2i[key]
|
i = self.index[key]
|
||||||
if i is None:
|
if i is None:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
else:
|
else:
|
||||||
|
@ -41,7 +49,7 @@ cdef class Vectors:
|
||||||
def __setitem__(self, key, vector):
|
def __setitem__(self, key, vector):
|
||||||
if isinstance(key, basestring):
|
if isinstance(key, basestring):
|
||||||
key = self.strings.add(key)
|
key = self.strings.add(key)
|
||||||
i = self.key2i[key]
|
i = self.index[key]
|
||||||
self.data[i] = vector
|
self.data[i] = vector
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
@ -61,34 +69,119 @@ cdef class Vectors:
|
||||||
def most_similar(self, key):
|
def most_similar(self, key):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def to_disk(self, path):
|
def to_disk(self, path, **exclude):
|
||||||
raise NotImplementedError
|
def serialize_vectors(p):
|
||||||
|
write_vectors_to_bin_loc(self.strings, self.key2i, self.data, str(p))
|
||||||
|
|
||||||
def from_disk(self, path):
|
serializers = OrderedDict((
|
||||||
raise NotImplementedError
|
('vec.bin', serialize_vectors),
|
||||||
|
))
|
||||||
|
return util.to_disk(serializers, exclude)
|
||||||
|
|
||||||
|
def from_disk(self, path, **exclude):
|
||||||
|
def deserialize_vectors(p):
|
||||||
|
self.key2i, self.vectors = load_vectors_from_bin_loc(self.strings, str(p))
|
||||||
|
|
||||||
|
serializers = OrderedDict((
|
||||||
|
('vec.bin', deserialize_vectors)
|
||||||
|
))
|
||||||
|
return util.to_disk(serializers, exclude)
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
def serialize_weights():
|
def serialize_weights():
|
||||||
if hasattr(self.weights, 'to_bytes'):
|
if hasattr(self.data, 'to_bytes'):
|
||||||
return self.weights.to_bytes()
|
return self.data.to_bytes()
|
||||||
else:
|
else:
|
||||||
return msgpack.dumps(self.weights)
|
return msgpack.dumps(self.data)
|
||||||
|
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
|
('key2row', lambda: msgpack.dumps(self.key2i)),
|
||||||
('strings', lambda: self.strings.to_bytes()),
|
('strings', lambda: self.strings.to_bytes()),
|
||||||
('weights', serialize_weights)
|
('vectors', serialize_weights)
|
||||||
))
|
))
|
||||||
return util.to_bytes(serializers, exclude)
|
return util.to_bytes(serializers, exclude)
|
||||||
|
|
||||||
def from_bytes(self, data, **exclude):
|
def from_bytes(self, data, **exclude):
|
||||||
def deserialize_weights(b):
|
def deserialize_weights(b):
|
||||||
if hasattr(self.weights, 'from_bytes'):
|
if hasattr(self.data, 'from_bytes'):
|
||||||
self.weights.from_bytes()
|
self.data.from_bytes()
|
||||||
else:
|
else:
|
||||||
self.weights = msgpack.loads(b)
|
self.data = msgpack.loads(b)
|
||||||
|
|
||||||
deserializers = OrderedDict((
|
deserializers = OrderedDict((
|
||||||
|
('key2row', lambda b: self.key2i.update(msgpack.loads(b))),
|
||||||
('strings', lambda b: self.strings.from_bytes(b)),
|
('strings', lambda b: self.strings.from_bytes(b)),
|
||||||
('weights', deserialize_weights)
|
('vectors', deserialize_weights)
|
||||||
))
|
))
|
||||||
return util.from_bytes(deserializers, exclude)
|
return util.from_bytes(deserializers, exclude)
|
||||||
|
|
||||||
|
|
||||||
|
def write_vectors_to_bin_loc(StringStore strings, dict key2i,
|
||||||
|
np.ndarray vectors, out_loc):
|
||||||
|
|
||||||
|
cdef int32_t vec_len = vectors.shape[1]
|
||||||
|
cdef int32_t word_len
|
||||||
|
cdef bytes word_str
|
||||||
|
cdef char* chars
|
||||||
|
cdef uint64_t key
|
||||||
|
cdef int32_t i
|
||||||
|
cdef float* vec
|
||||||
|
|
||||||
|
cdef CFile out_file = CFile(out_loc, 'wb')
|
||||||
|
keys = [(i, key) for (key, i) in key2i.item()]
|
||||||
|
keys.sort()
|
||||||
|
for i, key in keys:
|
||||||
|
vec = <float*>vectors.data[i * vec_len]
|
||||||
|
word_str = strings[key].encode('utf8')
|
||||||
|
word_len = len(word_str)
|
||||||
|
|
||||||
|
out_file.write_from(&word_len, 1, sizeof(word_len))
|
||||||
|
out_file.write_from(&vec_len, 1, sizeof(vec_len))
|
||||||
|
|
||||||
|
chars = <char*>word_str
|
||||||
|
out_file.write_from(chars, word_len, sizeof(char))
|
||||||
|
out_file.write_from(vec, vec_len, sizeof(float))
|
||||||
|
out_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
def load_vectors_from_bin_loc(StringStore strings, loc):
|
||||||
|
"""
|
||||||
|
Load vectors from the location of a binary file.
|
||||||
|
Arguments:
|
||||||
|
loc (unicode): The path of the binary file to load from.
|
||||||
|
Returns:
|
||||||
|
vec_len (int): The length of the vectors loaded.
|
||||||
|
"""
|
||||||
|
cdef CFile file_ = CFile(loc, b'rb')
|
||||||
|
cdef int32_t word_len
|
||||||
|
cdef int32_t vec_len = 0
|
||||||
|
cdef int32_t prev_vec_len = 0
|
||||||
|
cdef float* vec
|
||||||
|
cdef attr_t string_id
|
||||||
|
cdef bytes py_word
|
||||||
|
cdef vector[float*] vectors
|
||||||
|
cdef int line_num = 0
|
||||||
|
cdef Pool mem = Pool()
|
||||||
|
cdef dict key2i = {}
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
file_.read_into(&word_len, sizeof(word_len), 1)
|
||||||
|
except IOError:
|
||||||
|
break
|
||||||
|
file_.read_into(&vec_len, sizeof(vec_len), 1)
|
||||||
|
if prev_vec_len != 0 and vec_len != prev_vec_len:
|
||||||
|
raise Exception("Mismatched vector sizes")
|
||||||
|
if 0 >= vec_len >= MAX_VEC_SIZE:
|
||||||
|
raise Exception("Mismatched vector sizes")
|
||||||
|
|
||||||
|
chars = <char*>file_.alloc_read(mem, word_len, sizeof(char))
|
||||||
|
vec = <float*>file_.alloc_read(mem, vec_len, sizeof(float))
|
||||||
|
|
||||||
|
key = strings.add(chars[:word_len])
|
||||||
|
key2i[key] = vectors.size()
|
||||||
|
vectors.push_back(vec)
|
||||||
|
numpy_vectors = numpy.zeros((vectors.size(), vec_len), dtype='f')
|
||||||
|
for i in range(vectors.size()):
|
||||||
|
for j in range(vec_len):
|
||||||
|
numpy_vectors[i, j] = vectors[i][j]
|
||||||
|
return key2i, numpy_vectors
|
||||||
|
|
Loading…
Reference in New Issue
Block a user