Update train CLI

This commit is contained in:
Matthew Honnibal 2017-03-26 07:16:52 -05:00
parent 5eac089fbe
commit dec5571bf3
2 changed files with 32 additions and 18 deletions

View File

@@ -81,16 +81,17 @@ class CLI(object):
output_dir=("output directory", "positional", None, str), output_dir=("output directory", "positional", None, str),
train_data=("training data", "positional", None, str), train_data=("training data", "positional", None, str),
dev_data=("development data", "positional", None, str), dev_data=("development data", "positional", None, str),
n_iter=("number of iterations", "flag", "n", int), n_iter=("number of iterations", "option", "n", int),
tagger=("train tagger", "flag", "t", bool), no_tagger=("Don't train tagger", "flag", "T", bool),
parser=("train parser", "flag", "p", bool), no_parser=("Don't train parser", "flag", "P", bool),
ner=("train NER", "flag", "n", bool) no_ner=("Don't train NER", "flag", "N", bool)
) )
def train(self, lang, output_dir, train_data, dev_data, n_iter=15, tagger=True, def train(self, lang, output_dir, train_data, dev_data, n_iter=15,
parser=True, ner=True): no_tagger=False, no_parser=False, no_ner=False):
"""Train a model.""" """Train a model."""
cli_train(output_dir, train_data, dev_data, tagger, parser, ner) cli_train(lang, output_dir, train_data, dev_data, n_iter,
not no_tagger, not no_parser, not no_ner)
@plac.annotations( @plac.annotations(

View File

@@ -17,20 +17,28 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
output_path = Path(output_dir) output_path = Path(output_dir)
train_path = Path(train_data) train_path = Path(train_data)
dev_path = Path(dev_data) dev_path = Path(dev_data)
check_dirs(output_path, data_path, dev_path) check_dirs(output_path, train_path, dev_path)
lang = util.get_lang_class(language) lang = util.get_lang_class(language)
parser_cfg = dict(locals()) parser_cfg = {
tagger_cfg = dict(locals()) 'pseudoprojective': True,
entity_cfg = dict(locals()) 'n_iter': n_iter,
parser_cfg['features'] = lang.Defaults.parser_features 'lang': language,
entity_cfg['features'] = lang.Defaults.entity_features 'features': lang.Defaults.parser_features}
entity_cfg = {
'n_iter': n_iter,
'lang': language,
'features': lang.Defaults.entity_features}
tagger_cfg = {
'n_iter': n_iter,
'lang': language,
'features': lang.Defaults.tagger_features}
gold_train = list(read_gold_json(train_path)) gold_train = list(read_gold_json(train_path))
gold_dev = list(read_gold_json(dev_path)) gold_dev = list(read_gold_json(dev_path))
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg, train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
entity_cfg, n_iter) entity_cfg, n_iter)
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path) scorer = evaluate(lang, list(read_gold_json(dev_path)), output_path)
print_results(scorer) print_results(scorer)
@@ -79,7 +87,7 @@ def evaluate(Language, gold_tuples, output_path):
return scorer return scorer
def check_dirs(input_path, train_path, dev_path): def check_dirs(output_path, train_path, dev_path):
if not output_path.exists(): if not output_path.exists():
util.sys_exit(output_path.as_posix(), title="Output directory not found") util.sys_exit(output_path.as_posix(), title="Output directory not found")
if not train_path.exists() and train_path.is_file(): if not train_path.exists() and train_path.is_file():
@@ -92,7 +100,12 @@ def print_progress(itn, nr_weight, nr_active_feat, **scores):
def print_results(scorer): def print_results(scorer):
results = {'TOK': scorer.token_acc, 'POS': scorer.tags_acc, 'UAS': scorer.uas, results = {
'LAS': scorer.las, 'NER P': scorer.ents_p, 'NER R': scorer.ents_r, 'TOK': '%.2f' % scorer.token_acc,
'NER F': scorer.ents_f} 'POS': '%.2f' % scorer.tags_acc,
'UAS': '%.2f' % scorer.uas,
'LAS': '%.2f' % scorer.las,
'NER P': '%.2f' % scorer.ents_p,
'NER R': '%.2f' % scorer.ents_r,
'NER F': '%.2f' % scorer.ents_f}
util.print_table(results, title="Results") util.print_table(results, title="Results")