From 7ee880a0ade4c690afe74eaf09e40818ccc2a470 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 10 Apr 2018 14:30:04 +0000 Subject: [PATCH] Add support for .zip to init_model --- spacy/cli/init_model.py | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index e5a17d230..8a603fbcb 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -10,17 +10,12 @@ from pathlib import Path from preshed.counter import PreshCounter import tarfile import gzip +import zipfile -from ._messages import Messages +from ..compat import fix_text from ..vectors import Vectors -from ..errors import Warnings, user_warning from ..util import prints, ensure_path, get_lang_class -try: - import ftfy -except ImportError: - ftfy = None - @plac.annotations( lang=("model language", "positional", None, str), @@ -39,13 +34,16 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, vectors_loc= and word vectors. """ if freqs_loc is not None and not freqs_loc.exists(): - prints(freqs_loc, title=Messages.M037, exits=1) + prints(freqs_loc, title="Can't find words frequencies file", exits=1) clusters_loc = ensure_path(clusters_loc) vectors_loc = ensure_path(vectors_loc) + probs, oov_prob = read_freqs(freqs_loc) if freqs_loc is not None else ({}, -20) vectors_data, vector_keys = read_vectors(vectors_loc) if vectors_loc else (None, None) clusters = read_clusters(clusters_loc) if clusters_loc else {} + nlp = create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors) + if not output_dir.exists(): output_dir.mkdir() nlp.to_disk(output_dir) @@ -54,20 +52,26 @@ def init_model(lang, output_dir, freqs_loc=None, clusters_loc=None, vectors_loc= def open_file(loc): '''Handle .gz, .tar.gz or unzipped files''' loc = ensure_path(loc) + print("Open loc") if tarfile.is_tarfile(str(loc)): return tarfile.open(str(loc), 'r:gz') elif loc.parts[-1].endswith('gz'): return (line.decode('utf8') for line in gzip.open(str(loc), 'r')) + elif loc.parts[-1].endswith('zip'): + zip_file = zipfile.ZipFile(str(loc)) + names = zip_file.namelist() + file_ = zip_file.open(names[0]) + return (line.decode('utf8') for line in file_) else: return loc.open('r', encoding='utf8') - def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): print("Creating model...") lang_class = get_lang_class(lang) nlp = lang_class() for lexeme in nlp.vocab: lexeme.rank = 0 + lex_added = 0 for i, (word, prob) in enumerate(tqdm(sorted(probs.items(), key=lambda item: item[1], reverse=True))): lexeme = nlp.vocab[word] @@ -87,13 +91,15 @@ def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, pru lexeme = nlp.vocab[word] lexeme.is_oov = False lex_added += 1 + if len(vectors_data): nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys) if prune_vectors >= 1: nlp.vocab.prune_vectors(prune_vectors) vec_added = len(nlp.vocab.vectors) - prints(Messages.M039.format(entries=lex_added, vectors=vec_added), - title=Messages.M038) + + prints("{} entries, {} vectors".format(lex_added, vec_added), + title="Sucessfully compiled vocab") return nlp @@ -104,8 +110,12 @@ def read_vectors(vectors_loc): vectors_data = numpy.zeros(shape=shape, dtype='f') vectors_keys = [] for i, line in enumerate(tqdm(f)): - pieces = line.split() + line = line.rstrip() + pieces = line.rsplit(' ', vectors_data.shape[1]+1) word = pieces.pop(0) + if len(pieces) != vectors_data.shape[1]: + print(word, repr(line)) + raise ValueError("Bad line in file") vectors_data[i] = numpy.asarray(pieces, dtype='f') vectors_keys.append(word) return vectors_data, vectors_keys @@ -140,14 +150,11 @@ def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50): def read_clusters(clusters_loc): print("Reading clusters...") clusters = {} - if ftfy is None: - user_warning(Warnings.W004) with clusters_loc.open() as f: for line in tqdm(f): try: cluster, word, freq = line.split() - if ftfy is not None: - word = ftfy.fix_text(word) + word = fix_text(word) except ValueError: continue # If the clusterer has only seen the word a few times, its