spaCy/spacy/vectors.pyx

150 lines
4.4 KiB
Cython
Raw Normal View History

2017-08-19 22:27:35 +03:00
from __future__ import unicode_literals
2017-08-18 21:45:48 +03:00
from libc.stdint cimport int32_t, uint64_t
2017-06-05 13:32:08 +03:00
import numpy
from collections import OrderedDict
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
2017-08-18 21:45:48 +03:00
cimport numpy as np
2017-06-05 13:32:08 +03:00
2017-08-18 21:45:48 +03:00
from .typedefs cimport attr_t
2017-06-05 13:32:08 +03:00
from .strings cimport StringStore
from . import util
2017-08-19 20:52:25 +03:00
from .compat import basestring_
2017-06-05 13:32:08 +03:00
cdef class Vectors:
    """Store, save and load word vectors.

    Vectors live in a 2-D table (`data`), one row per entry. `key2row` maps
    a key to its row index, and `keys` is the inverse table (row index to
    key). `i` is the number of rows handed out so far and is the next row
    that `add()` will assign.
    """
    cdef public object data
    cdef readonly StringStore strings
    cdef public object key2row
    cdef public object keys
    cdef public int i

    def __init__(self, strings, data_or_width):
        """Create the vector table.

        strings: A `StringStore` to share, or an iterable of strings that is
            copied into a new store. `None` gives an empty store.
        data_or_width: Either an existing 2-D array to use as the table, or
            an int width, in which case a zeroed float table with
            `len(strings)` rows is allocated.
        """
        # Keep the caller's strings. Replacing them with an empty store made
        # `__reduce__` (which passes `self.strings` back in) lose all strings.
        if isinstance(strings, StringStore):
            self.strings = strings
        else:
            self.strings = StringStore()
            if strings is not None:
                for string in strings:
                    self.strings.add(string)
        if isinstance(data_or_width, int):
            data = numpy.zeros((len(strings), data_or_width), dtype='f')
        else:
            data = data_or_width
        self.i = 0
        self.data = data
        self.key2row = {}
        # Zero-filled rather than uninitialised, so that the unused slots
        # written by to_disk()/to_bytes() are deterministic.
        self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')

    def __reduce__(self):
        """Support pickling by re-running the constructor on strings + data."""
        return (Vectors, (self.strings, self.data))

    def __getitem__(self, key):
        """Return the vector stored for `key`.

        key: A string, or the integer key it resolves to.
        RAISES: KeyError if the key has no row.
        """
        # `basestring_` is the py2/py3-compatible alias; the bare
        # `basestring` name does not exist on Python 3.
        if isinstance(key, basestring_):
            key = self.strings[key]
        # The dict lookup raises KeyError(key) for unknown keys.
        return self.data[self.key2row[key]]

    def __setitem__(self, key, vector):
        """Overwrite the vector of an existing `key`.

        A string key is added to the string store first.
        RAISES: KeyError if the key has no row yet (use `add()` for that).
        """
        if isinstance(key, basestring_):
            key = self.strings.add(key)
        self.data[self.key2row[key]] = vector

    def __iter__(self):
        """Yield every row of the data table, including unassigned rows."""
        for row in self.data:
            yield row

    def __len__(self):
        """Return the number of rows handed out so far."""
        return self.i

    def __contains__(self, key):
        """Check whether `key` (string or integer key) has a row."""
        if isinstance(key, basestring_):
            key = self.strings[key]
        return key in self.key2row

    def add(self, key, vector=None):
        """Ensure `key` has a row, optionally storing `vector` in it.

        key: A string (added to the string store) or an integer key.
        vector: Optional vector to write into the key's row.
        RETURNS: The row index of the key.
        """
        if isinstance(key, basestring_):
            key = self.strings.add(key)
        if key in self.key2row:
            row = self.key2row[key]
        else:
            row = self.i
            # Grow each table on its own: their sizes can differ (e.g. after
            # loading), and plain doubling never grows an empty table.
            if row >= self.keys.shape[0]:
                self.keys.resize((max(row + 1, self.keys.shape[0] * 2),))
            if row >= self.data.shape[0]:
                self.data.resize((max(row + 1, self.data.shape[0] * 2),
                                  self.data.shape[1]))
            self.key2row[key] = row
            self.keys[row] = key
            self.i += 1
        if vector is not None:
            self.data[row] = vector
        return row

    def items(self):
        """Yield `(string, vector)` pairs.

        NOTE(review): this pairs the n-th entry of the string store with the
        n-th data row and does not consult `key2row`. That is only correct
        if strings were interned in the same order rows were assigned --
        confirm against callers.
        """
        for i, string in enumerate(self.strings):
            yield string, self.data[i]

    @property
    def shape(self):
        """The shape of the underlying data table."""
        return self.data.shape

    def most_similar(self, key):
        """Not implemented yet."""
        raise NotImplementedError

    def to_disk(self, path, **exclude):
        """Save the data and keys tables via `util.to_disk`.

        Each table is written with `numpy.save` (pickling disabled) under
        the names 'vectors' and 'keys'.
        """
        # Use context managers so the file handles are closed after writing.
        def save_vectors(p):
            with p.open('wb') as file_:
                numpy.save(file_, self.data, allow_pickle=False)

        def save_keys(p):
            with p.open('wb') as file_:
                numpy.save(file_, self.keys, allow_pickle=False)

        serializers = OrderedDict((
            ('vectors', save_vectors),
            ('keys', save_keys),
        ))
        return util.to_disk(path, serializers, exclude)

    def from_disk(self, path, **exclude):
        """Load the keys and data tables via `util.from_disk`.

        Missing files are skipped.
        RETURNS: This object, modified in place.
        """
        def load_keys(path):
            if path.exists():
                self.keys = numpy.load(path)
                for row, key in enumerate(self.keys):
                    self.key2row[key] = row
                # Every loaded slot is now indexed, so the next free row is
                # past the end. Leaving `i` at 0 made the next add() reuse
                # row 0 and clobber a loaded vector.
                self.i = self.keys.shape[0]

        def load_vectors(path):
            if path.exists():
                self.data = numpy.load(path)

        serializers = OrderedDict((
            ('keys', load_keys),
            ('vectors', load_vectors),
        ))
        util.from_disk(path, serializers, exclude)
        return self

    def to_bytes(self, **exclude):
        """Serialize the keys and data tables via `util.to_bytes`."""
        def serialize_weights():
            # Prefer the table's own serializer when it provides one.
            if hasattr(self.data, 'to_bytes'):
                return self.data.to_bytes()
            else:
                return msgpack.dumps(self.data)

        serializers = OrderedDict((
            ('keys', lambda: msgpack.dumps(self.keys)),
            ('vectors', serialize_weights)
        ))
        return util.to_bytes(serializers, exclude)

    def from_bytes(self, data, **exclude):
        """Restore the keys and data tables via `util.from_bytes`.

        RETURNS: This object, modified in place.
        """
        def deserialize_weights(b):
            if hasattr(self.data, 'from_bytes'):
                # Pass the payload through; it was previously dropped.
                # NOTE(review): assumes `from_bytes` updates the table in
                # place, since its return value is ignored -- confirm.
                self.data.from_bytes(b)
            else:
                self.data = msgpack.loads(b)

        def load_keys(keys):
            self.keys.resize((len(keys),))
            for row, key in enumerate(keys):
                self.keys[row] = key
                self.key2row[key] = row
            # Keep len()/add() in step with the restored index.
            self.i = len(keys)

        deserializers = OrderedDict((
            ('keys', lambda b: load_keys(msgpack.loads(b))),
            ('vectors', deserialize_weights)
        ))
        util.from_bytes(data, deserializers, exclude)
        return self