diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index c285a12a6..f2114085a 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -206,7 +206,7 @@ def read_vectors(vectors_loc): from tqdm import tqdm f = open_file(vectors_loc) - shape = tuple(int(size) for size in next(f).split()) + shape, f = _get_shape(f) vectors_data = numpy.zeros(shape=shape, dtype="f") vectors_keys = [] for i, line in enumerate(tqdm(f)): @@ -220,6 +220,21 @@ def read_vectors(vectors_loc): return vectors_data, vectors_keys +def _get_shape(file_): + """Return a tuple with (number of entries, vector dimensions). Handle + both word2vec/FastText format, which has a header with this, or GloVe's + format, which doesn't.""" + first_line = next(file_).split() + if len(first_line) == 2: + return tuple(int(size) for size in first_line), file_ + count = 1 + for line in file_: + count += 1 + file_.seek(0) + shape = (count, len(first_line)-1) + return shape, file_ + + def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50): # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200 from tqdm import tqdm