mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Add GloVe vectors reader
This commit is contained in:
parent
7e04b7f89c
commit
7742a6d559
|
@ -90,6 +90,33 @@ cdef class Vectors:
|
|||
def most_similar(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def from_glove(self, path):
|
||||
'''Load GloVe vectors from a directory. Assumes binary format,
|
||||
that the vocab is in a vocab.txt, and that vectors are named
|
||||
vectors.{size}.[fd].bin, e.g. vectors.128.f.bin for 128d float32
|
||||
vectors, vectors.300.d.bin for 300d float64 (double) vectors, etc.
|
||||
By default GloVe outputs 64-bit vectors.'''
|
||||
path = util.ensure_path(path)
|
||||
for name in path.iterdir():
|
||||
if name.parts[-1].startswith('vectors'):
|
||||
_, dims, dtype, _2 = name.parts[-1].split('.')
|
||||
self.width = int(dims)
|
||||
break
|
||||
else:
|
||||
raise IOError("Expected file named e.g. vectors.128.f.bin")
|
||||
bin_loc = path / 'vectors.{dims}.{dtype}.bin'.format(dims=dims,
|
||||
dtype=dtype)
|
||||
with bin_loc.open('rb') as file_:
|
||||
self.data = numpy.fromfile(file_, dtype='float64')
|
||||
self.data = numpy.ascontiguousarray(self.data, dtype='float32')
|
||||
n = 0
|
||||
with (path / 'vocab.txt').open('r') as file_:
|
||||
for line in file_:
|
||||
self.add(line.strip())
|
||||
n += 1
|
||||
if (self.data.size % self.width) == 0:
|
||||
self.data
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
serializers = OrderedDict((
|
||||
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
|
||||
|
|
Loading…
Reference in New Issue
Block a user