From c94742ff6453a3edb1b4e8a839d6d2bb0e61c717 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 16 Mar 2019 15:55:31 +0100 Subject: [PATCH] Only add beam width if customised --- spacy/cli/train.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index c74ec8663..42965edd0 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -143,6 +143,7 @@ def train( if 1 not in eval_beam_widths: eval_beam_widths.append(1) eval_beam_widths.sort() + has_beam_widths = eval_beam_widths != [1] # Set up the base model and pipeline. If a base model is specified, load # the model and make sure the pipeline matches the pipeline setting. If @@ -211,11 +212,11 @@ def train( # fmt: off row_head = ("Itn", "Beam Width", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS") - row_settings = { - "widths": (3, 10, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7), - "aligns": tuple(["r" for i in row_head]), - "spacing": 2 - } + row_widths = (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7) + if has_beam_widths: + row_head.insert(1, "Beam W.") + row_widths.insert(1, 7) + row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2} # fmt: on print("") msg.row(row_head, **row_settings) @@ -318,8 +319,11 @@ def train( srsly.write_json(meta_loc, meta) util.set_env_log(verbose) + progress_args = [i, losses, scorer.scores] + if has_beam_widths: + progress_args.inset(1, beam_with) progress = _get_progress( - i, beam_width, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps + *progress_args, cpu_wps=cpu_wps, gpu_wps=gpu_wps ) msg.row(progress, **row_settings) finally: