Improve train CLI

This commit is contained in:
Matthew Honnibal 2017-06-04 20:18:37 -05:00
parent a053b1218e
commit c52fde40f4

View File

@ -18,6 +18,7 @@ from ..gold import GoldCorpus, minibatch
from ..util import prints
from .. import util
from .. import displacy
from ..compat import json_dumps
@plac.annotations(
@ -44,7 +45,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
train_path = util.ensure_path(train_data)
dev_path = util.ensure_path(dev_data)
if not output_path.exists():
prints(output_path, title="Output directory not found", exits=1)
output_path.mkdir()
if not train_path.exists():
prints(train_path, title="Training data not found", exits=1)
if dev_path and not dev_path.exists():
@ -74,7 +75,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
else:
nlp = lang_class(pipeline=pipeline)
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
n_train_docs = corpus.count_train()
n_train_words = corpus.count_train()
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
@ -83,7 +84,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
for i in range(n_iter):
if resume:
i += 20
with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
train_docs = corpus.train_docs(nlp, projectivize=True,
gold_preproc=False, max_length=0)
losses = {}
@ -91,7 +92,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
docs, golds = zip(*batch)
nlp.update(docs, golds, sgd=optimizer,
drop=next(dropout_rates), losses=losses)
pbar.update(len(docs))
pbar.update(sum(len(doc) for doc in docs))
with nlp.use_params(optimizer.averages):
util.set_env_log(False)
@ -105,6 +106,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
corpus.dev_docs(
nlp_loaded,
gold_preproc=False))
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
with acc_loc.open('w') as file_:
file_.write(json_dumps(scorer.scores))
util.set_env_log(True)
print_progress(i, losses, scorer.scores)
finally: