mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Fix vectors deserialization
This commit is contained in:
parent
42d47c1e5c
commit
19c495f451
|
@ -20,7 +20,7 @@ cdef class Vectors:
|
||||||
'''Store, save and load word vectors.'''
|
'''Store, save and load word vectors.'''
|
||||||
cdef public object data
|
cdef public object data
|
||||||
cdef readonly StringStore strings
|
cdef readonly StringStore strings
|
||||||
cdef public object index
|
cdef public object key2row
|
||||||
|
|
||||||
def __init__(self, strings, data_or_width):
|
def __init__(self, strings, data_or_width):
|
||||||
self.strings = StringStore()
|
self.strings = StringStore()
|
||||||
|
@ -30,9 +30,9 @@ cdef class Vectors:
|
||||||
else:
|
else:
|
||||||
data = data_or_width
|
data = data_or_width
|
||||||
self.data = data
|
self.data = data
|
||||||
self.index = {}
|
self.key2row = {}
|
||||||
for i, string in enumerate(strings):
|
for i, string in enumerate(strings):
|
||||||
self.index[self.strings.add(string)] = i
|
self.key2row[self.strings.add(string)] = i
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Vectors, (self.strings, self.data))
|
return (Vectors, (self.strings, self.data))
|
||||||
|
@ -40,7 +40,7 @@ cdef class Vectors:
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, basestring):
|
if isinstance(key, basestring):
|
||||||
key = self.strings[key]
|
key = self.strings[key]
|
||||||
i = self.index[key]
|
i = self.key2row[key]
|
||||||
if i is None:
|
if i is None:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
else:
|
else:
|
||||||
|
@ -49,7 +49,7 @@ cdef class Vectors:
|
||||||
def __setitem__(self, key, vector):
|
def __setitem__(self, key, vector):
|
||||||
if isinstance(key, basestring):
|
if isinstance(key, basestring):
|
||||||
key = self.strings.add(key)
|
key = self.strings.add(key)
|
||||||
i = self.index[key]
|
i = self.key2row[key]
|
||||||
self.data[i] = vector
|
self.data[i] = vector
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
@ -71,7 +71,7 @@ cdef class Vectors:
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
def serialize_vectors(p):
|
def serialize_vectors(p):
|
||||||
write_vectors_to_bin_loc(self.strings, self.key2i, self.data, str(p))
|
write_vectors_to_bin_loc(self.strings, self.key2row, self.data, str(p))
|
||||||
|
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('vec.bin', serialize_vectors),
|
('vec.bin', serialize_vectors),
|
||||||
|
@ -80,12 +80,13 @@ cdef class Vectors:
|
||||||
|
|
||||||
def from_disk(self, path, **exclude):
|
def from_disk(self, path, **exclude):
|
||||||
def deserialize_vectors(p):
|
def deserialize_vectors(p):
|
||||||
self.key2i, self.vectors = load_vectors_from_bin_loc(self.strings, str(p))
|
values = load_vectors_from_bin_loc(self.strings, str(p))
|
||||||
|
self.key2row, self.data = values
|
||||||
|
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('vec.bin', deserialize_vectors)
|
('vec.bin', deserialize_vectors),
|
||||||
))
|
))
|
||||||
return util.to_disk(serializers, exclude)
|
return util.from_disk(path, serializers, exclude)
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
def serialize_weights():
|
def serialize_weights():
|
||||||
|
@ -93,9 +94,9 @@ cdef class Vectors:
|
||||||
return self.data.to_bytes()
|
return self.data.to_bytes()
|
||||||
else:
|
else:
|
||||||
return msgpack.dumps(self.data)
|
return msgpack.dumps(self.data)
|
||||||
|
b = msgpack.dumps(self.key2row)
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('key2row', lambda: msgpack.dumps(self.key2i)),
|
('key2row', lambda: msgpack.dumps(self.key2row)),
|
||||||
('strings', lambda: self.strings.to_bytes()),
|
('strings', lambda: self.strings.to_bytes()),
|
||||||
('vectors', serialize_weights)
|
('vectors', serialize_weights)
|
||||||
))
|
))
|
||||||
|
@ -109,7 +110,7 @@ cdef class Vectors:
|
||||||
self.data = msgpack.loads(b)
|
self.data = msgpack.loads(b)
|
||||||
|
|
||||||
deserializers = OrderedDict((
|
deserializers = OrderedDict((
|
||||||
('key2row', lambda b: self.key2i.update(msgpack.loads(b))),
|
('key2row', lambda b: self.key2row.update(msgpack.loads(b))),
|
||||||
('strings', lambda b: self.strings.from_bytes(b)),
|
('strings', lambda b: self.strings.from_bytes(b)),
|
||||||
('vectors', deserialize_weights)
|
('vectors', deserialize_weights)
|
||||||
))
|
))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user