Rename model command to init_model and fix formatting

This commit is contained in:
ines 2017-12-07 09:59:23 +01:00
parent 2feeb428d6
commit 82e80ff928
3 changed files with 13 additions and 17 deletions

View File

@ -7,7 +7,7 @@ if __name__ == '__main__':
import plac import plac
import sys import sys
from spacy.cli import download, link, info, package, train, convert from spacy.cli import download, link, info, package, train, convert
from spacy.cli import vocab, profile, evaluate, validate from spacy.cli import vocab, init_model, profile, evaluate, validate
from spacy.util import prints from spacy.util import prints
commands = { commands = {
@ -19,6 +19,7 @@ if __name__ == '__main__':
'convert': convert, 'convert': convert,
'package': package, 'package': package,
'vocab': vocab, 'vocab': vocab,
'init-model': init_model,
'profile': profile, 'profile': profile,
'validate': validate 'validate': validate
} }

View File

@ -7,4 +7,5 @@ from .train import train
from .evaluate import evaluate from .evaluate import evaluate
from .convert import convert from .convert import convert
from .vocab import make_vocab as vocab from .vocab import make_vocab as vocab
from .init_model import init_model
from .validate import validate from .validate import validate

View File

@ -3,18 +3,15 @@ from __future__ import unicode_literals
import plac import plac
import math import math
from tqdm import tqdm from tqdm import tqdm
import spacy
import numpy import numpy
from ast import literal_eval from ast import literal_eval
from pathlib import Path from pathlib import Path
from preshed.counter import PreshCounter from preshed.counter import PreshCounter
from spacy.compat import fix_text from ...compat import fix_text
from spacy.vectors import Vectors from ...vectors import Vectors
from spacy.util import prints, ensure_path from ...util import prints, ensure_path, get_lang_class
@plac.annotations( @plac.annotations(
@ -29,7 +26,7 @@ from spacy.util import prints, ensure_path
prune_vectors=("optional: number of vectors to prune to", prune_vectors=("optional: number of vectors to prune to",
"option", "V", int) "option", "V", int)
) )
def main(lang, output_dir, freqs_loc, clusters_loc=None, vectors_loc=None, prune_vectors=-1): def init_model(lang, output_dir, freqs_loc, clusters_loc=None, vectors_loc=None, prune_vectors=-1):
if not freqs_loc.exists(): if not freqs_loc.exists():
prints(freqs_loc, title="Can't find words frequencies file", exits=1) prints(freqs_loc, title="Can't find words frequencies file", exits=1)
clusters_loc = ensure_path(clusters_loc) clusters_loc = ensure_path(clusters_loc)
@ -48,8 +45,9 @@ def main(lang, output_dir, freqs_loc, clusters_loc=None, vectors_loc=None, prune
def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors): def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, prune_vectors):
prints("Creating model...") print("Creating model...")
nlp = spacy.blank(lang) lang_class = get_lang_class(lang)
nlp = lang_class()
for lexeme in nlp.vocab: for lexeme in nlp.vocab:
lexeme.rank = 0 lexeme.rank = 0
@ -80,7 +78,7 @@ def create_model(lang, probs, oov_prob, clusters, vectors_data, vector_keys, pru
def read_vectors(vectors_loc): def read_vectors(vectors_loc):
prints("Reading vectors...") print("Reading vectors...")
with vectors_loc.open() as f: with vectors_loc.open() as f:
shape = tuple(int(size) for size in f.readline().split()) shape = tuple(int(size) for size in f.readline().split())
vectors_data = numpy.zeros(shape=shape, dtype='f') vectors_data = numpy.zeros(shape=shape, dtype='f')
@ -94,7 +92,7 @@ def read_vectors(vectors_loc):
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50): def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
prints("Counting frequencies...") print("Counting frequencies...")
counts = PreshCounter() counts = PreshCounter()
total = 0 total = 0
with freqs_loc.open() as f: with freqs_loc.open() as f:
@ -120,7 +118,7 @@ def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
def read_clusters(clusters_loc): def read_clusters(clusters_loc):
prints("Reading clusters...") print("Reading clusters...")
clusters = {} clusters = {}
with clusters_loc.open() as f: with clusters_loc.open() as f:
for line in tqdm(f): for line in tqdm(f):
@ -144,7 +142,3 @@ def read_clusters(clusters_loc):
if word.upper() not in clusters: if word.upper() not in clusters:
clusters[word.upper()] = cluster clusters[word.upper()] = cluster
return clusters return clusters
if __name__ == '__main__':
plac.call(main)