spaCy/spacy/vectors.pyx
2017-06-05 12:36:04 +02:00

96 lines
2.6 KiB
Cython

import numpy
from collections import OrderedDict
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
from .strings cimport StringStore
from . import util
cdef class Vectors:
    '''Store, save and load word vectors.

    Vectors are kept as rows of `data`. `key2i` maps the value returned by
    `self.strings.add(string)` for each string to the index of its row.
    '''
    # Table of vectors, indexed by row number.
    cdef public object data
    # String store populated from the strings passed to __init__.
    cdef readonly StringStore strings
    # Mapping from string-store key to row index in `data`.
    cdef public object key2i

    def __init__(self, strings, data_or_width):
        '''Create the vector table.

        strings: The strings to register, one per row, in row order. Must
            support len() when `data_or_width` is an int.
        data_or_width: Either an int giving the width of each vector, in
            which case a zero-filled table of shape (len(strings), width)
            with dtype 'f' is allocated, or an existing table, which is
            stored as-is.
        '''
        self.strings = StringStore()
        if isinstance(data_or_width, int):
            self.data = numpy.zeros((len(strings), data_or_width), dtype='f')
        else:
            self.data = data_or_width
        self.key2i = {}
        for i, string in enumerate(strings):
            self.key2i[self.strings.add(string)] = i

    def __reduce__(self):
        '''Support pickling by re-running __init__ on the strings and data.'''
        return (Vectors, (self.strings, self.data))

    def __getitem__(self, key):
        '''Return the row of `data` registered for `key`.

        key: A string, or a key as stored in `key2i`.
        Raises KeyError if no row is registered for the key.
        '''
        if isinstance(key, basestring):
            key = self.strings[key]
        i = self.key2i.get(key)
        if i is None:
            raise KeyError(key)
        return self.data[i]

    def __setitem__(self, key, vector):
        '''Overwrite the row registered for `key` with `vector`.

        key: A string, or a key as stored in `key2i`. A string is added to
            the string store if it is not there yet, but no new row is
            allocated for it.
        vector: The values to assign to the row.
        Raises KeyError if no row is registered for the key.
        '''
        if isinstance(key, basestring):
            key = self.strings.add(key)
        i = self.key2i[key]
        self.data[i] = vector

    def __iter__(self):
        '''Yield the rows of `data` in order.'''
        yield from self.data

    def __len__(self):
        '''Return the number of entries in the string store.'''
        return len(self.strings)

    def items(self):
        '''Yield (string, vector) pairs.

        NOTE(review): pairs the i-th string yielded by the string store with
        row i, so this assumes the store iterates in insertion order --
        confirm against StringStore.
        '''
        for i, string in enumerate(self.strings):
            yield string, self.data[i]

    @property
    def shape(self):
        '''The shape of the underlying `data` table.'''
        return self.data.shape

    def most_similar(self, key):
        '''Not implemented yet.'''
        raise NotImplementedError

    def to_disk(self, path):
        '''Not implemented yet.'''
        raise NotImplementedError

    def from_disk(self, path):
        '''Not implemented yet.'''
        raise NotImplementedError

    def to_bytes(self, **exclude):
        '''Serialize the string store and the vector table.

        **exclude: Forwarded to `util.to_bytes` as a dict.
        Returns whatever `util.to_bytes` returns for the serializers.
        '''
        def serialize_weights():
            # The table lives in `self.data`; this class has no `weights`
            # attribute. Prefer the table's own serializer when it has one.
            if hasattr(self.data, 'to_bytes'):
                return self.data.to_bytes()
            else:
                return msgpack.dumps(self.data)

        serializers = OrderedDict((
            ('strings', lambda: self.strings.to_bytes()),
            ('weights', serialize_weights)
        ))
        return util.to_bytes(serializers, exclude)

    def from_bytes(self, data, **exclude):
        '''Load the string store and the vector table from `data`.

        data: The serialized payload, as produced by `to_bytes`.
        **exclude: Forwarded to `util.from_bytes` as a dict.
        Returns whatever `util.from_bytes` returns.
        '''
        def deserialize_weights(b):
            # Mirror of serialize_weights: operate on `self.data` and hand
            # the payload to the table's own loader when it has one.
            if hasattr(self.data, 'from_bytes'):
                self.data.from_bytes(b)
            else:
                self.data = msgpack.loads(b)

        deserializers = OrderedDict((
            ('strings', lambda b: self.strings.from_bytes(b)),
            ('weights', deserialize_weights)
        ))
        # NOTE(review): assumes util.from_bytes takes the payload as its
        # first argument, i.e. (bytes_data, setters, exclude) -- confirm
        # against util.
        return util.from_bytes(data, deserializers, exclude)