From eb7cbb62c24be7573f2146e18d99117f3b071fde Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 5 Jun 2017 12:32:08 +0200
Subject: [PATCH] Flesh out Vectors class

---
Review note: the blob hash on the index line below predates the review
edits made to the file body.

 spacy/vectors.pyx | 132 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 132 insertions(+)
 create mode 100644 spacy/vectors.pyx

diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx
new file mode 100644
index 000000000..36ab1e316
--- /dev/null
+++ b/spacy/vectors.pyx
@@ -0,0 +1,132 @@
+import numpy
+from collections import OrderedDict
+import msgpack
+import msgpack_numpy
+msgpack_numpy.patch()
+
+from .strings cimport StringStore
+from . import util
+
+
+cdef class Vectors:
+    '''Store, save and load word vectors.
+
+    Keys are added to `strings`; `key2i` maps the value returned by
+    `strings.add` to the index of the key's row in `data`.
+    '''
+    cdef public object data
+    cdef readonly StringStore strings
+    cdef public object key2i
+
+    def __init__(self, strings, data_or_width):
+        '''Create the table.
+
+        strings: Sized iterable of keys; the i-th key is assigned row i.
+        data_or_width: Either an int width, in which case a zero-filled
+            table of dtype 'f' with one row per key is allocated, or an
+            existing table, which is stored as-is.
+        '''
+        self.strings = StringStore()
+        if isinstance(data_or_width, int):
+            self.data = numpy.zeros((len(strings), data_or_width), dtype='f')
+        else:
+            self.data = data_or_width
+        self.key2i = {}
+        for i, string in enumerate(strings):
+            self.key2i[self.strings.add(string)] = i
+
+    def __reduce__(self):
+        raise NotImplementedError
+
+    def __getitem__(self, key):
+        '''Return the row of `data` stored for `key`.
+
+        String keys are resolved through `strings` first. Raises KeyError
+        if `key2i` has no entry (or a None entry) for the key.
+        '''
+        if isinstance(key, basestring):
+            key = self.strings[key]
+        i = self.key2i[key]
+        # `key2i` is publicly writable, so a None entry is possible.
+        if i is None:
+            raise KeyError(key)
+        return self.data[i]
+
+    def __setitem__(self, key, vector):
+        '''Overwrite the row stored for `key` with `vector`.
+
+        String keys are passed through `strings.add` first. Raises KeyError
+        if `key2i` has no entry for the key.
+        '''
+        if isinstance(key, basestring):
+            key = self.strings.add(key)
+        self.data[self.key2i[key]] = vector
+
+    def __iter__(self):
+        '''Yield the rows of `data` in order.'''
+        yield from self.data
+
+    def __len__(self):
+        '''Return the number of entries in `strings`.'''
+        return len(self.strings)
+
+    def items(self):
+        '''Yield (string, row) pairs; string i goes with row i of `data`.'''
+        for i, string in enumerate(self.strings):
+            yield string, self.data[i]
+
+    @property
+    def shape(self):
+        '''The shape of `data`.'''
+        return self.data.shape
+
+    def most_similar(self, key):
+        raise NotImplementedError
+
+    def to_disk(self, path):
+        raise NotImplementedError
+
+    def from_disk(self, path):
+        raise NotImplementedError
+
+    def to_bytes(self, **exclude):
+        '''Serialize the string store and the vector table.
+
+        exclude: Keyword arguments, handed to `util.to_bytes` as a dict;
+            presumably the names of fields to leave out (TODO confirm).
+        RETURNS: The result of `util.to_bytes`.
+        '''
+        def serialize_weights():
+            # The table is `self.data`; the class declares no `weights`.
+            if hasattr(self.data, 'to_bytes'):
+                return self.data.to_bytes()
+            return msgpack.dumps(self.data)
+
+        # NOTE(review): `key2i` is not included in the payload.
+        serializers = OrderedDict((
+            ('strings', lambda: self.strings.to_bytes()),
+            ('weights', serialize_weights)
+        ))
+        return util.to_bytes(serializers, exclude)
+
+    def from_bytes(self, data, **exclude):
+        '''Load the string store and the vector table from `data`.
+
+        data: The serialized payload, presumably from `to_bytes`.
+        exclude: Keyword arguments, handed to `util.from_bytes` as a dict;
+            presumably the names of fields to skip (TODO confirm).
+        RETURNS: The result of `util.from_bytes`.
+        '''
+        def deserialize_weights(b):
+            if hasattr(self.data, 'from_bytes'):
+                self.data.from_bytes(b)
+            else:
+                self.data = msgpack.loads(b)
+
+        deserializers = OrderedDict((
+            ('strings', lambda b: self.strings.from_bytes(b)),
+            ('weights', deserialize_weights)
+        ))
+        # NOTE(review): assumes the signature
+        # util.from_bytes(bytes_data, setters, exclude) -- confirm.
+        return util.from_bytes(data, deserializers, exclude)