From 52429625f0cb0d9731811fc92d6892c79295711d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 20 Mar 2015 01:14:20 +0100 Subject: [PATCH] * Add write_parses function --- bin/parser/train.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/bin/parser/train.py b/bin/parser/train.py index fdadc9d02..b4ae0c596 100755 --- a/bin/parser/train.py +++ b/bin/parser/train.py @@ -237,6 +237,7 @@ def train(Language, train_loc, model_dir, n_iter=15, feat_set=u'basic', seed=0, nlp.parser.model.end_training() nlp.entity.model.end_training() nlp.tagger.model.end_training() + print nlp.vocab.strings['NMOD'] def evaluate(Language, dev_loc, model_dir, gold_preproc=False, verbose=True): @@ -251,16 +252,34 @@ def evaluate(Language, dev_loc, model_dir, gold_preproc=False, verbose=True): return scorer +def write_parses(Language, dev_loc, model_dir, out_loc): + nlp = Language() + gold_tuples = read_docparse_file(dev_loc) + scorer = Scorer() + out_file = codecs.open(out_loc, 'w', 'utf8') + for raw_text, segmented_text, annot_tuples in gold_tuples: + tokens = nlp(raw_text) + for t in tokens: + out_file.write( + '%s\t%s\t%s\t%s\n' % (t.orth_, t.tag_, t.head.orth_, t.dep_) + ) + print nlp.vocab.strings['NMOD'] + return scorer + + @plac.annotations( train_loc=("Training file location",), dev_loc=("Dev. file location",), model_dir=("Location of output model directory",), + out_loc=("Out location", "option", "o", str), n_sents=("Number of training sentences", "option", "n", int), verbose=("Verbose error reporting", "flag", "v", bool), ) -def main(train_loc, dev_loc, model_dir, n_sents=0, verbose=False): - train(English, train_loc, model_dir, - gold_preproc=False, force_gold=False, n_sents=n_sents) +def main(train_loc, dev_loc, model_dir, n_sents=0, out_loc="", verbose=False): + #train(English, train_loc, model_dir, + # gold_preproc=False, force_gold=False, n_sents=n_sents) + if out_loc: + write_parses(English, dev_loc, model_dir, out_loc) scorer = evaluate(English, dev_loc, model_dir, gold_preproc=False, verbose=verbose) print 'POS', scorer.tags_acc print 'UAS', scorer.uas