Refactor train script

This commit is contained in:
Matthew Honnibal 2017-09-20 19:17:10 -05:00
parent ffda38356a
commit 1d73dec8b1

View File

@ -8,6 +8,7 @@ import cytoolz
from pathlib import Path
import dill
import tqdm
from thinc.neural._classes.model import Model
from thinc.neural.optimizers import linear_decay
from timeit import default_timer as timer
@ -69,18 +70,20 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
batch_sizes = util.compounding(util.env_opt('batch_from', 1),
util.env_opt('batch_to', 64),
util.env_opt('batch_compound', 1.001))
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
n_train_words = corpus.count_train()
if not resume:
lang_class = util.get_lang_class(lang)
nlp = lang_class(pipeline=pipeline)
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
else:
print("Load resume")
nlp = _resume_model(lang, pipeline)
util.use_gpu(use_gpu)
nlp = _resume_model(lang, pipeline, corpus)
optimizer = nlp.resume_training(device=use_gpu)
lang_class = nlp.__class__
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
n_train_words = corpus.count_train()
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
nlp._optimizer = None
print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
@ -101,11 +104,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
util.set_env_log(False)
epoch_model_path = output_path / ('model%d' % i)
nlp.to_disk(epoch_model_path)
nlp_loaded = lang_class(pipeline=pipeline)
nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
scorer = nlp_loaded.evaluate(
#nlp_loaded = lang_class(pipeline=pipeline)
#nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
scorer = nlp.evaluate(
corpus.dev_docs(
nlp_loaded,
nlp,
gold_preproc=gold_preproc))
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
with acc_loc.open('w') as file_:
@ -114,19 +117,30 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
print_progress(i, losses, scorer.scores)
finally:
print("Saving model...")
try:
with (output_path / 'model-final.pickle').open('wb') as file_:
with nlp.use_params(optimizer.averages):
dill.dump(nlp, file_, -1)
except:
pass
def _resume_model(lang, pipeline):
def _resume_model(lang, pipeline, corpus):
nlp = util.load_model(lang)
pipes = {getattr(pipe, 'name', None) for pipe in nlp.pipeline}
for name in pipeline:
if name not in pipes:
factory = nlp.Defaults.factories[name]
nlp.pipeline.extend(factory(nlp))
for pipe in factory(nlp):
if hasattr(pipe, 'begin_training'):
pipe.begin_training(corpus.train_tuples,
pipeline=nlp.pipeline)
nlp.pipeline.append(pipe)
nlp.meta['pipeline'] = pipeline
if nlp.vocab.vectors.data.shape[1] >= 1:
nlp.vocab.vectors.data = Model.ops.asarray(
nlp.vocab.vectors.data)
return nlp