mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Pass dev_scores to print_progress correctly (resolves #1008)
Only read scores attribute if command is used with dev_data, otherwise default dev_scores to empty dict.
This commit is contained in:
parent
ade920c30f
commit
3a9710f356
|
@ -62,10 +62,10 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_
|
||||||
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)):
|
||||||
for doc, gold in epoch:
|
for doc, gold in epoch:
|
||||||
trainer.update(doc, gold)
|
trainer.update(doc, gold)
|
||||||
dev_scores = trainer.evaluate(dev_data) if dev_data else []
|
dev_scores = trainer.evaluate(dev_data).scores if dev_data else {}
|
||||||
print_progress(itn, trainer.nlp.parser.model.nr_weight,
|
print_progress(itn, trainer.nlp.parser.model.nr_weight,
|
||||||
trainer.nlp.parser.model.nr_active_feat,
|
trainer.nlp.parser.model.nr_active_feat,
|
||||||
**dev_scores.scores)
|
**dev_scores)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(Language, gold_tuples, output_path):
|
def evaluate(Language, gold_tuples, output_path):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user