mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Flesh out Vectors class
This commit is contained in:
parent
a4dcc96c54
commit
eb7cbb62c2
95
spacy/vectors.pyx
Normal file
95
spacy/vectors.pyx
Normal file
|
@ -0,0 +1,95 @@
|
|||
import numpy
|
||||
from collections import OrderedDict
|
||||
import msgpack
|
||||
import msgpack_numpy
|
||||
msgpack_numpy.patch()
|
||||
|
||||
from .strings cimport StringStore
|
||||
from . import util
|
||||
|
||||
|
||||
cdef class Vectors:
|
||||
'''Store, save and load word vectors.'''
|
||||
cdef public object data
|
||||
cdef readonly StringStore strings
|
||||
cdef public object key2i
|
||||
|
||||
def __init__(self, strings, data_or_width):
|
||||
self.strings = StringStore()
|
||||
if isinstance(data_or_width, int):
|
||||
self.data = data = numpy.zeros((len(strings), data_or_width),
|
||||
dtype='f')
|
||||
else:
|
||||
data = data_or_width
|
||||
self.data = data
|
||||
self.key2i = {}
|
||||
for i, string in enumerate(strings):
|
||||
self.key2i[self.strings.add(string)] = i
|
||||
|
||||
def __reduce__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, basestring):
|
||||
key = self.strings[key]
|
||||
i = self.key2i[key]
|
||||
if i is None:
|
||||
raise KeyError(key)
|
||||
else:
|
||||
return self.data[i]
|
||||
|
||||
def __setitem__(self, key, vector):
|
||||
if isinstance(key, basestring):
|
||||
key = self.strings.add(key)
|
||||
i = self.key2i[key]
|
||||
self.data[i] = vector
|
||||
print("Set", i, vector)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.strings)
|
||||
|
||||
def items(self):
|
||||
for i, string in enumerate(self.strings):
|
||||
yield string, self.data[i]
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.data.shape
|
||||
|
||||
def most_similar(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def to_disk(self, path):
|
||||
raise NotImplementedError
|
||||
|
||||
def from_disk(self, path):
|
||||
raise NotImplementedError
|
||||
|
||||
def to_bytes(self, **exclude):
|
||||
def serialize_weights():
|
||||
if hasattr(self.weights, 'to_bytes'):
|
||||
return self.weights.to_bytes()
|
||||
else:
|
||||
return msgpack.dumps(self.weights)
|
||||
|
||||
serializers = OrderedDict((
|
||||
('strings', lambda: self.strings.to_bytes()),
|
||||
('weights', serialize_weights)
|
||||
))
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(self, data, **exclude):
|
||||
def deserialize_weights(b):
|
||||
if hasattr(self.weights, 'from_bytes'):
|
||||
self.weights.from_bytes()
|
||||
else:
|
||||
self.weights = msgpack.loads(b)
|
||||
|
||||
deserializers = OrderedDict((
|
||||
('strings', lambda b: self.strings.from_bytes(b)),
|
||||
('weights', deserialize_weights)
|
||||
))
|
||||
return util.from_bytes(deserializers, exclude)
|
Loading…
Reference in New Issue
Block a user