* Update bin/parser/train for printing output.

Matthew Honnibal 2015-10-06 10:35:22 +11:00
parent 3d9f41c2c9
commit c503654ec1


@@ -148,8 +148,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
     nlp.end_training(model_dir)
     print('done')

 def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False,
-             beam_width=None):
+             beam_width=None, cand_preproc=None):
     nlp = Language(data_dir=model_dir)
     if beam_width is not None:
         nlp.parser.cfg.beam_width = beam_width
@@ -166,16 +167,14 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
             nlp.entity(tokens)
             nlp.parser(tokens)
         else:
-            tokens = nlp(raw_text, merge_mwes=False)
+            tokens = nlp(raw_text)
         gold = GoldParse(tokens, annot_tuples)
         scorer.score(tokens, gold, verbose=verbose)
     return scorer


-def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
+def write_parses(Language, dev_loc, model_dir, out_loc):
     nlp = Language(data_dir=model_dir)
-    if beam_width is not None:
-        nlp.parser.cfg.beam_width = beam_width
     gold_tuples = read_json_file(dev_loc)
     scorer = Scorer()
     out_file = codecs.open(out_loc, 'w', 'utf8')
@@ -188,14 +187,16 @@ def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
             nlp.entity(tokens)
             nlp.parser(tokens)
         else:
-            tokens = nlp(raw_text, merge_mwes=False)
-        gold = GoldParse(tokens, annot_tuples)
-        scorer.score(tokens, gold, verbose=False)
-        for t in tokens:
-            out_file.write(
-                '%s\t%s\t%s\t%s\n' % (t.orth_, t.tag_, t.head.orth_, t.dep_)
-            )
-    return scorer
+            tokens = nlp(raw_text)
+        #gold = GoldParse(tokens, annot_tuples)
+        #scorer.score(tokens, gold, verbose=False)
+        for sent in tokens.sents:
+            for t in sent:
+                if not t.is_space:
+                    out_file.write(
+                        '%d\t%s\t%s\t%s\t%s\n' % (t.i, t.orth_, t.tag_, t.head.orth_, t.dep_)
+                    )
+            out_file.write('\n')


 @plac.annotations(
@@ -220,14 +221,15 @@ def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbos
           gold_preproc=gold_preproc, n_sents=n_sents,
           corruption_level=corruption_level, n_iter=n_iter,
           verbose=verbose)
-    #if out_loc:
-    #    write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
+    if out_loc:
+        write_parses(English, dev_loc, model_dir, out_loc)
     scorer = evaluate(English, list(read_json_file(dev_loc)),
                       model_dir, gold_preproc=gold_preproc, verbose=verbose)
     print('TOK', scorer.token_acc)
     print('POS', scorer.tags_acc)
     print('UAS', scorer.uas)
     print('LAS', scorer.las)
+    print('SBD', scorer.sbd_acc)
     print('NER P', scorer.ents_p)
     print('NER R', scorer.ents_r)
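
For reference, the rewritten write_parses loop prints one token per line as tab-separated fields (token index, word form, POS tag, head word, dependency label), skips whitespace tokens, and writes a blank line after each sentence. A minimal sketch of that formatting, assuming a parsed Doc from the loaded pipeline; the format_parses helper is illustrative and not part of this commit:

    def format_parses(tokens):
        # Mirror the layout written by write_parses: index, word, tag,
        # head word and dependency label, tab-separated, one token per line,
        # whitespace tokens skipped, blank line after each sentence.
        lines = []
        for sent in tokens.sents:
            for t in sent:
                if not t.is_space:
                    lines.append('%d\t%s\t%s\t%s\t%s' % (
                        t.i, t.orth_, t.tag_, t.head.orth_, t.dep_))
            lines.append('')
        return '\n'.join(lines) + '\n'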