Make dev data optional

This commit is contained in:
ines 2017-03-26 11:48:17 +02:00
parent 0fc56e2544
commit 53cf2f1c0e

View File

@ -16,7 +16,7 @@ from .. import util
def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner):
output_path = Path(output_dir)
train_path = Path(train_data)
dev_path = Path(dev_data)
dev_path = Path(dev_data) if dev_data else None
check_dirs(output_path, data_path, dev_path)
lang = util.get_lang_class(language)
@ -26,12 +26,13 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne
parser_cfg['features'] = lang.Defaults.parser_features
entity_cfg['features'] = lang.Defaults.entity_features
gold_train = list(read_gold_json(train_path))
gold_dev = list(read_gold_json(dev_path))
gold_dev = list(read_gold_json(dev_path)) if dev_path else None
train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg,
entity_cfg, n_iter)
scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path)
print_results(scorer)
if gold_dev:
scorer = evaluate(lang, gold_dev, output_path)
print_results(scorer)
def train_config(config):
@ -54,7 +55,7 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
for doc, gold in epoch:
trainer.update(doc, gold)
dev_scores = trainer.evaluate(dev_data)
dev_scores = trainer.evaluate(dev_data) if dev_data else []
print_progress(itn, trainer.nlp.parser.model.nr_weight,
trainer.nlp.parser.model.nr_active_feat,
**dev_scores.scores)
@ -82,8 +83,10 @@ def evaluate(Language, gold_tuples, output_path):
def check_dirs(input_path, train_path, dev_path):
if not output_path.exists():
util.sys_exit(output_path.as_posix(), title="Output directory not found")
if not train_path.exists() and train_path.is_file():
if not train_path.exists() or not train_path.is_file():
util.sys_exit(train_path.as_posix(), title="Training data not found")
if dev_path and not dev_path.exists():
util.sys_exit(dev_path.as_posix(), title="Development data not found")
def print_progress(itn, nr_weight, nr_active_feat, **scores):