diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index e5a17d230..629fb92b8 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -10,6 +10,7 @@ from pathlib import Path from preshed.counter import PreshCounter import tarfile import gzip +import zipfile from ._messages import Messages from ..vectors import Vectors @@ -54,14 +55,19 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, vectors_loc= def open_file(loc): '''Handle .gz, .tar.gz or unzipped files''' loc = ensure_path(loc) + print("Open loc") if tarfile.is_tarfile(str(loc)): return tarfile.open(str(loc), 'r:gz') elif loc.parts[-1].endswith('gz'): return (line.decode('utf8') for line in gzip.open(str(loc), 'r')) + elif loc.parts[-1].endswith('zip'): + zip_file = zipfile.ZipFile(str(loc)) + names = zip_file.namelist() + file_ = zip_file.open(names[0]) + return (line.decode('utf8') for line in file_) else: return loc.open('r', encoding='utf8') - def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): print("Creating model...") lang_class = get_lang_class(lang) @@ -104,8 +110,12 @@ def read_vectors(vectors_loc): vectors_data = numpy.zeros(shape=shape, dtype='f') vectors_keys = [] for i, line in enumerate(tqdm(f)): - pieces = line.split() + line = line.rstrip() + pieces = line.rsplit(' ', vectors_data.shape[1]+1) word = pieces.pop(0) + if len(pieces) != vectors_data.shape[1]: + print(word, repr(line)) + raise ValueError("Bad line in file") vectors_data[i] = numpy.asarray(pieces, dtype='f') vectors_keys.append(word) return vectors_data, vectors_keys