Add GloVe vectors reader

This commit is contained in:
Matthew Honnibal 2017-09-01 16:39:22 +02:00
parent 7e04b7f89c
commit 7742a6d559

View File

@ -90,6 +90,33 @@ cdef class Vectors:
def most_similar(self, key): def most_similar(self, key):
raise NotImplementedError raise NotImplementedError
def from_glove(self, path):
'''Load GloVe vectors from a directory. Assumes binary format,
that the vocab is in a vocab.txt, and that vectors are named
vectors.{size}.[fd].bin, e.g. vectors.128.f.bin for 128d float32
vectors, vectors.300.d.bin for 300d float64 (double) vectors, etc.
By default GloVe outputs 64-bit vectors.'''
path = util.ensure_path(path)
for name in path.iterdir():
if name.parts[-1].startswith('vectors'):
_, dims, dtype, _2 = name.parts[-1].split('.')
self.width = int(dims)
break
else:
raise IOError("Expected file named e.g. vectors.128.f.bin")
bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
dtype=dtype)
with bin_loc.open('rb') as file_:
self.data = numpy.fromfile(file_, dtype='float64')
self.data = numpy.ascontiguousarray(self.data, dtype='float32')
n = 0
with (path / 'vocab.txt').open('r') as file_:
for line in file_:
self.add(line.strip())
n += 1
if (self.data.size % self.width) == 0:
self.data
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
serializers = OrderedDict(( serializers = OrderedDict((
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)), ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),