Ported model cli from v1

This commit is contained in:
Gyorgy Orosz 2017-08-19 21:45:23 +02:00
parent 5cb0200e63
commit e5344b83a3
3 changed files with 134 additions and 6 deletions

View File

@ -3,15 +3,21 @@ from __future__ import print_function
# NB! This breaks in plac on Python 2!! # NB! This breaks in plac on Python 2!!
#from __future__ import unicode_literals #from __future__ import unicode_literals
if __name__ == '__main__': if __name__ == '__main__':
import plac import plac
import sys import sys
from spacy.cli import download, link, info, package, train, convert from spacy.cli import download, link, info, package, train, convert, model
from spacy.util import prints from spacy.util import prints
commands = {'download': download, 'link': link, 'info': info, 'train': train, commands = {
'convert': convert, 'package': package} 'download': download,
'link': link,
'info': info,
'train': train,
'convert': convert,
'package': package,
'model': model
}
if len(sys.argv) == 1: if len(sys.argv) == 1:
prints(', '.join(commands), title="Available commands", exits=1) prints(', '.join(commands), title="Available commands", exits=1)
command = sys.argv.pop(1) command = sys.argv.pop(1)
@ -19,5 +25,7 @@ if __name__ == '__main__':
if command in commands: if command in commands:
plac.call(commands[command]) plac.call(commands[command])
else: else:
prints("Available: %s" % ', '.join(commands), prints(
title="Unknown command: %s" % command, exits=1) "Available: %s" % ', '.join(commands),
title="Unknown command: %s" % command,
exits=1)

View File

@ -4,3 +4,4 @@ from .link import link
from .package import package from .package import package
from .train import train from .train import train
from .convert import convert from .convert import convert
from .model import model

119
spacy/cli/model.py Normal file
View File

@ -0,0 +1,119 @@
# coding: utf8
from __future__ import unicode_literals
import gzip
import math
from ast import literal_eval
from pathlib import Path
from preshed.counter import PreshCounter
import spacy
from ..compat import fix_text
from .. import util
def model(cmd, lang, model_dir, freqs_data, clusters_data, vectors_data):
model_path = Path(model_dir)
freqs_path = Path(freqs_data)
clusters_path = Path(clusters_data) if clusters_data else None
vectors_path = Path(vectors_data) if vectors_data else None
check_dirs(freqs_path, clusters_path, vectors_path)
# vocab = util.get_lang_class(lang).Defaults.create_vocab()
nlp = spacy.blank(lang)
vocab = nlp.vocab
probs, oov_prob = read_probs(freqs_path)
clusters = read_clusters(clusters_path) if clusters_path else {}
populate_vocab(vocab, clusters, probs, oov_prob)
create_model(model_path, nlp)
def create_model(model_path, model):
if not model_path.exists():
model_path.mkdir()
model.to_disk(model_path.as_posix())
def read_probs(freqs_path, max_length=100, min_doc_freq=5, min_freq=200):
counts = PreshCounter()
total = 0
freqs_file = check_unzip(freqs_path)
for i, line in enumerate(freqs_file):
freq, doc_freq, key = line.rstrip().split('\t', 2)
freq = int(freq)
counts.inc(i + 1, freq)
total += freq
counts.smooth()
log_total = math.log(total)
freqs_file = check_unzip(freqs_path)
probs = {}
for line in freqs_file:
freq, doc_freq, key = line.rstrip().split('\t', 2)
doc_freq = int(doc_freq)
freq = int(freq)
if doc_freq >= min_doc_freq and freq >= min_freq and len(
key) < max_length:
word = literal_eval(key)
smooth_count = counts.smoother(int(freq))
probs[word] = math.log(smooth_count) - log_total
oov_prob = math.log(counts.smoother(0)) - log_total
return probs, oov_prob
def read_clusters(clusters_path):
clusters = {}
with clusters_path.open() as f:
for line in f:
try:
cluster, word, freq = line.split()
word = fix_text(word)
except ValueError:
continue
# If the clusterer has only seen the word a few times, its
# cluster is unreliable.
if int(freq) >= 3:
clusters[word] = cluster
else:
clusters[word] = '0'
# Expand clusters with re-casing
for word, cluster in list(clusters.items()):
if word.lower() not in clusters:
clusters[word.lower()] = cluster
if word.title() not in clusters:
clusters[word.title()] = cluster
if word.upper() not in clusters:
clusters[word.upper()] = cluster
return clusters
def populate_vocab(vocab, clusters, probs, oov_prob):
for word, prob in reversed(
sorted(list(probs.items()), key=lambda item: item[1])):
lexeme = vocab[word]
lexeme.prob = prob
lexeme.is_oov = False
# Decode as a little-endian string, so that we can do & 15 to get
# the first 4 bits. See _parse_features.pyx
if word in clusters:
lexeme.cluster = int(clusters[word][::-1], 2)
else:
lexeme.cluster = 0
def check_unzip(file_path):
file_path_str = file_path.as_posix()
if file_path_str.endswith('gz'):
return gzip.open(file_path_str)
else:
return file_path.open()
def check_dirs(freqs_data, clusters_data, vectors_data):
if not freqs_data.is_file():
util.sys_exit(freqs_data.as_posix(), title="No frequencies file found")
if clusters_data and not clusters_data.is_file():
util.sys_exit(
clusters_data.as_posix(), title="No Brown clusters file found")
if vectors_data and not vectors_data.is_file():
util.sys_exit(
vectors_data.as_posix(), title="No word vectors file found")