Pass dev_scores to print_progress correctly (resolves #1008)

Only read scores attribute if command is used with dev_data, otherwise
default dev_scores to empty dict.
This commit is contained in:
ines 2017-04-23 15:57:53 +02:00
parent ade920c30f
commit 3a9710f356

View File

@ -62,10 +62,10 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
for doc, gold in epoch: for doc, gold in epoch:
trainer.update(doc, gold) trainer.update(doc, gold)
dev_scores = trainer.evaluate(dev_data) if dev_data else [] dev_scores = trainer.evaluate(dev_data).scores if dev_data else {}
print_progress(itn, trainer.nlp.parser.model.nr_weight, print_progress(itn, trainer.nlp.parser.model.nr_weight,
trainer.nlp.parser.model.nr_active_feat, trainer.nlp.parser.model.nr_active_feat,
**dev_scores.scores) **dev_scores)
def evaluate(Language, gold_tuples, output_path): def evaluate(Language, gold_tuples, output_path):