mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix beam width integration
This commit is contained in:
parent
c94742ff64
commit
e7aa25d9b1
|
@ -211,7 +211,7 @@ def train(
|
|||
msg.text("Loaded pretrained tok2vec for: {}".format(components))
|
||||
|
||||
# fmt: off
|
||||
row_head = ("Itn", "Beam Width", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS")
|
||||
row_head = ("Itn", "Dep Loss", "NER Loss", "UAS", "NER P", "NER R", "NER F", "Tag %", "Token %", "CPU WPS", "GPU WPS")
|
||||
row_widths = (3, 10, 10, 7, 7, 7, 7, 7, 7, 7, 7)
|
||||
if has_beam_widths:
|
||||
row_head.insert(1, "Beam W.")
|
||||
|
@ -319,11 +319,13 @@ def train(
|
|||
srsly.write_json(meta_loc, meta)
|
||||
util.set_env_log(verbose)
|
||||
|
||||
progress_args = [i, losses, scorer.scores]
|
||||
if has_beam_widths:
|
||||
progress_args.inset(1, beam_with)
|
||||
progress = _get_progress(
|
||||
*progress_args, cpu_wps=cpu_wps, gpu_wps=gpu_wps
|
||||
i,
|
||||
losses,
|
||||
scorer.scores,
|
||||
beam_width=beam_width if has_beam_widths else None,
|
||||
cpu_wps=cpu_wps,
|
||||
gpu_wps=gpu_wps,
|
||||
)
|
||||
msg.row(progress, **row_settings)
|
||||
finally:
|
||||
|
@ -411,7 +413,7 @@ def _get_metrics(component):
|
|||
return ("token_acc",)
|
||||
|
||||
|
||||
def _get_progress(itn, beam_width, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0):
|
||||
def _get_progress(itn, losses, dev_scores, beam_width=None, cpu_wps=0.0, gpu_wps=0.0):
|
||||
scores = {}
|
||||
for col in [
|
||||
"dep_loss",
|
||||
|
@ -432,9 +434,8 @@ def _get_progress(itn, beam_width, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0)
|
|||
scores.update(dev_scores)
|
||||
scores["cpu_wps"] = cpu_wps
|
||||
scores["gpu_wps"] = gpu_wps or 0.0
|
||||
return [
|
||||
result = [
|
||||
itn,
|
||||
beam_width,
|
||||
"{:.3f}".format(scores["dep_loss"]),
|
||||
"{:.3f}".format(scores["ner_loss"]),
|
||||
"{:.3f}".format(scores["uas"]),
|
||||
|
@ -446,3 +447,6 @@ def _get_progress(itn, beam_width, losses, dev_scores, cpu_wps=0.0, gpu_wps=0.0)
|
|||
"{:.0f}".format(scores["cpu_wps"]),
|
||||
"{:.0f}".format(scores["gpu_wps"]),
|
||||
]
|
||||
if beam_width is not None:
|
||||
result.insert(1, beam_width)
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue
Block a user