Improve train CLI script

This commit is contained in:
Matthew Honnibal 2017-06-03 13:28:20 -05:00
parent d21459f87d
commit 43353b5413

View File

@ -28,15 +28,17 @@ from .. import displacy
n_iter=("number of iterations", "option", "n", int),
n_sents=("number of sentences", "option", "ns", int),
use_gpu=("Use GPU", "flag", "G", bool),
resume=("Whether to resume training", "flag", "R", bool),
no_tagger=("Don't train tagger", "flag", "T", bool),
no_parser=("Don't train parser", "flag", "P", bool),
no_entities=("Don't train NER", "flag", "N", bool)
)
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
use_gpu=False, no_tagger=False, no_parser=False, no_entities=False):
use_gpu=False, resume=False, no_tagger=False, no_parser=False, no_entities=False):
"""
Train a model. Expects data in spaCy's JSON format.
"""
util.set_env_log(True)
n_sents = n_sents or None
output_path = util.ensure_path(output_dir)
train_path = util.ensure_path(train_data)
@ -66,6 +68,10 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
util.env_opt('batch_to', 64),
util.env_opt('batch_compound', 1.001))
if resume:
prints(output_path / 'model19.pickle', title="Resuming training")
nlp = dill.load((output_path / 'model19.pickle').open('rb'))
else:
nlp = lang_class(pipeline=pipeline)
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
n_train_docs = corpus.count_train()
@ -75,6 +81,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
try:
for i in range(n_iter):
if resume:
i += 20
with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
train_docs = corpus.train_docs(nlp, projectivize=True,
gold_preproc=False, max_length=0)
@ -86,14 +94,18 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
pbar.update(len(docs))
with nlp.use_params(optimizer.averages):
util.set_env_log(False)
epoch_model_path = output_path / ('model%d' % i)
nlp.to_disk(epoch_model_path)
with (output_path / ('model%d.pickle' % i)).open('wb') as file_:
dill.dump(nlp, file_, -1)
with (output_path / ('model%d.bin' % i)).open('wb') as file_:
file_.write(nlp.to_bytes())
with (output_path / ('model%d.bin' % i)).open('rb') as file_:
nlp_loaded = lang_class(pipeline=pipeline)
nlp_loaded.from_bytes(file_.read())
scorer = nlp_loaded.evaluate(corpus.dev_docs(nlp_loaded, gold_preproc=False))
nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
scorer = nlp_loaded.evaluate(
corpus.dev_docs(
nlp_loaded,
gold_preproc=False))
util.set_env_log(True)
print_progress(i, losses, scorer.scores)
finally:
print("Saving model...")