mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Only add beam width if customised
This commit is contained in:
parent
7a354761c7
commit
c94742ff64
|
@ -143,6 +143,7 @@ def train(
|
|||
if 1 not in eval_beam_widths:
|
||||
eval_beam_widths.append(1)
|
||||
eval_beam_widths.sort()
|
||||
has_beam_widths = eval_beam_widths != [1]
|
||||
|
||||
# Set up the base model and pipeline. If a base model is specified, load
|
||||
# the model and make sure the pipeline matches the pipeline setting. If
|
||||
|
@ -211,11 +212,11 @@ def train(
|
|||
|
||||
# fmt: off
|
||||
row_head = ("Itn", "Beam Width", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS")
|
||||
row_settings = {
|
||||
"widths": (3, 10, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7),
|
||||
"aligns": tuple(["r" for i in row_head]),
|
||||
"spacing": 2
|
||||
}
|
||||
row_widths = (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7)
|
||||
if has_beam_widths:
|
||||
row_head.insert(1, "Beam W.")
|
||||
row_widths.insert(1, 7)
|
||||
row_settings = {"widths": row_widths, "aligns": tuple(["r" for i in row_head]), "spacing": 2}
|
||||
# fmt: on
|
||||
print("")
|
||||
msg.row(row_head, **row_settings)
|
||||
|
@ -318,8 +319,11 @@ def train(
|
|||
srsly.write_json(meta_loc, meta)
|
||||
util.set_env_log(verbose)
|
||||
|
||||
progress_args = [i, losses, scorer.scores]
|
||||
if has_beam_widths:
|
||||
progress_args.inset(1, beam_with)
|
||||
progress = _get_progress(
|
||||
i, beam_width, losses, scorer.scores, cpu_wps=cpu_wps, gpu_wps=gpu_wps
|
||||
*progress_args, cpu_wps=cpu_wps, gpu_wps=gpu_wps
|
||||
)
|
||||
msg.row(progress, **row_settings)
|
||||
finally:
|
||||
|
|
Loading…
Reference in New Issue
Block a user