Save models properly in conll_train.py

This commit is contained in:
Matthew Honnibal 2016-07-31 11:42:17 +02:00
parent 5869f05bd6
commit e38632003d

View File

@@ -171,8 +171,7 @@ def train(Language, gold_tuples, model_dir, dev_loc, n_iter=15, feat_set=u'basic
except KeyboardInterrupt:
print("Saving model...")
break
#nlp.end_training(model_dir)
nlp.parser.model.end_training()
nlp.end_training(model_dir)
print("Saved. Evaluating...")
return nlp
@@ -215,7 +214,7 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
learn_rate=0.001, update_step='sgd_cm'):
with io.open(train_loc, 'r', encoding='utf8') as file_:
train_sents = list(read_conll(file_))
# preprocess training data here before ArcEager.get_labels() is called
# Preprocess training data here before ArcEager.get_labels() is called
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
nlp = train(English, train_sents, model_dir, dev_loc, n_iter=n_iter,
@@ -225,11 +224,6 @@ def main(train_loc, dev_loc, model_dir, n_iter=15, neural=False, batch_norm=Fals
update_step=update_step)
scorer = score_file(nlp, dev_loc)
#scorer = Scorer()
#with io.open(dev_loc, 'r', encoding='utf8') as file_:
# for _, sents in read_conll(file_):
# for annot_tuples, _ in sents:
# score_model(scorer, nlp, None, annot_tuples)
print('TOK', scorer.token_acc)
print('POS', scorer.tags_acc)
print('UAS', scorer.uas)