Support GloVe vectors txt format in init-model

This commit is contained in:
Matthew Honnibal 2019-11-02 12:47:10 +01:00
parent 55f2241d72
commit 106b0b9a79

View File

@ -206,7 +206,7 @@ def read_vectors(vectors_loc):
from tqdm import tqdm from tqdm import tqdm
f = open_file(vectors_loc) f = open_file(vectors_loc)
shape = tuple(int(size) for size in next(f).split()) shape, f = _get_shape(f)
vectors_data = numpy.zeros(shape=shape, dtype="f") vectors_data = numpy.zeros(shape=shape, dtype="f")
vectors_keys = [] vectors_keys = []
for i, line in enumerate(tqdm(f)): for i, line in enumerate(tqdm(f)):
@ -220,6 +220,21 @@ def read_vectors(vectors_loc):
return vectors_data, vectors_keys return vectors_data, vectors_keys
def _get_shape(file_):
    """Return ((number of entries, vector dimensions), file_).

    Handle both the word2vec/FastText format, whose first line is a header
    holding these two integers, and GloVe's format, which has no header and
    starts directly with the first vector.

    file_: An iterator over lines. In the headerless case it must also
        support ``seek(0)``, because every line is consumed to count the rows.

    Returns a ``(shape, file_)`` pair. With a header, ``file_`` is left
    positioned just after the header line. Without one, ``file_`` is rewound
    to the start so that the first vector is not skipped by the caller.

    Raises StopIteration if ``file_`` yields no lines at all.
    """
    first_line = next(file_).split()
    if len(first_line) == 2:
        try:
            return tuple(int(size) for size in first_line), file_
        except ValueError:
            # Two tokens that are not both integers: this is not a header but
            # a headerless file with one-dimensional vectors (key + 1 value).
            # Fall through and count the rows instead of crashing.
            pass
    # No header: the line already read is itself a vector, hence the 1.
    count = 1 + sum(1 for _ in file_)
    # Rewind so the caller sees the first vector again.
    file_.seek(0)
    # Every row is "<key> <v1> ... <vn>", so the width excludes the key.
    shape = (count, len(first_line) - 1)
    return shape, file_
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50): def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200 # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
from tqdm import tqdm from tqdm import tqdm