* Ensure better separation between score printing and training in train.py

This commit is contained in:
Matthew Honnibal 2015-03-24 04:25:38 +01:00
parent 6d49f8717b
commit 221f43c370

View File

@ -218,6 +218,11 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
for itn in range(n_iter):
scorer = Scorer()
for raw_text, segmented_text, annot_tuples in gold_tuples:
# Eval before train
tokens = nlp(raw_text)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=False)
if gold_preproc:
sents = [nlp.tokenizer.tokens_from_list(s) for s in segmented_text]
else:
@ -229,15 +234,11 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0,
nlp.entity.train(tokens, gold, force_gold=force_gold)
nlp.tagger.train(tokens, gold.tags)
tokens = nlp(raw_text)
gold = GoldParse(tokens, annot_tuples)
scorer.score(tokens, gold, verbose=False)
print '%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.ents_f, scorer.tags_acc)
random.shuffle(gold_tuples)
nlp.parser.model.end_training()
nlp.entity.model.end_training()
nlp.tagger.model.end_training()
print nlp.vocab.strings['NMOD']
def evaluate(Language, dev_loc, model_dir, gold_preproc=False, verbose=True):
@ -274,13 +275,16 @@ def write_parses(Language, dev_loc, model_dir, out_loc):
out_loc=("Out location", "option", "o", str),
n_sents=("Number of training sentences", "option", "n", int),
verbose=("Verbose error reporting", "flag", "v", bool),
debug=("Debug mode", "flag", "d", bool)
)
def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False):
train(English, train_loc, model_dir,
def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False,
debug=False):
train(English, train_loc, model_dir, feat_set='basic' if not debug else 'debug',
gold_preproc=False, force_gold=False, n_sents=n_sents)
if out_loc:
write_parses(English, dev_loc, model_dir, out_loc)
scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False, verbose=verbose)
print 'TOK', scorer.mistokened
print 'POS', scorer.tags_acc
print 'UAS', scorer.uas
print 'LAS', scorer.las