From 43353b5413285ff9409978c5336d88b108d60ca2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 3 Jun 2017 13:28:20 -0500 Subject: [PATCH] Improve train CLI script --- spacy/cli/train.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index a2c06c571..bc0664917 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -28,15 +28,17 @@ from .. import displacy n_iter=("number of iterations", "option", "n", int), n_sents=("number of sentences", "option", "ns", int), use_gpu=("Use GPU", "flag", "G", bool), + resume=("Whether to resume training", "flag", "R", bool), no_tagger=("Don't train tagger", "flag", "T", bool), no_parser=("Don't train parser", "flag", "P", bool), no_entities=("Don't train NER", "flag", "N", bool) ) def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, - use_gpu=False, no_tagger=False, no_parser=False, no_entities=False): + use_gpu=False, resume=False, no_tagger=False, no_parser=False, no_entities=False): """ Train a model. Expects data in spaCy's JSON format. """ + util.set_env_log(True) n_sents = n_sents or None output_path = util.ensure_path(output_dir) train_path = util.ensure_path(train_data) @@ -66,7 +68,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, util.env_opt('batch_to', 64), util.env_opt('batch_compound', 1.001)) - nlp = lang_class(pipeline=pipeline) + if resume: + prints(output_path / 'model19.pickle', title="Resuming training") + nlp = dill.load((output_path / 'model19.pickle').open('rb')) + else: + nlp = lang_class(pipeline=pipeline) corpus = GoldCorpus(train_path, dev_path, limit=n_sents) n_train_docs = corpus.count_train() @@ -75,6 +81,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %") try: for i in range(n_iter): + if resume: + i += 20 with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar: train_docs = corpus.train_docs(nlp, projectivize=True, gold_preproc=False, max_length=0) @@ -86,14 +94,18 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, pbar.update(len(docs)) with nlp.use_params(optimizer.averages): + util.set_env_log(False) + epoch_model_path = output_path / ('model%d' % i) + nlp.to_disk(epoch_model_path) with (output_path / ('model%d.pickle' % i)).open('wb') as file_: dill.dump(nlp, file_, -1) - with (output_path / ('model%d.bin' % i)).open('wb') as file_: - file_.write(nlp.to_bytes()) - with (output_path / ('model%d.bin' % i)).open('rb') as file_: - nlp_loaded = lang_class(pipeline=pipeline) - nlp_loaded.from_bytes(file_.read()) - scorer = nlp_loaded.evaluate(corpus.dev_docs(nlp_loaded, gold_preproc=False)) + nlp_loaded = lang_class(pipeline=pipeline) + nlp_loaded = nlp_loaded.from_disk(epoch_model_path) + scorer = nlp_loaded.evaluate( + corpus.dev_docs( + nlp_loaded, + gold_preproc=False)) + util.set_env_log(True) print_progress(i, losses, scorer.scores) finally: print("Saving model...")