Port train CLI updates

Updates from #5362 and fix from #5387:

* `train`:

  * if training on GPU, run the extra CPU evaluation (used for timing)
    only in the first iteration

  * if training is aborted, exit with a non-zero exit status
This commit is contained in:
Adriane Boyd 2020-06-03 14:03:43 +02:00
parent 581bda9f98
commit f1f9c8b417

View File

@@ -458,22 +458,25 @@ def train(
cpu_wps = nwords / (end_time - start_time) cpu_wps = nwords / (end_time - start_time)
else: else:
gpu_wps = nwords / (end_time - start_time) gpu_wps = nwords / (end_time - start_time)
with use_ops("numpy"): # Evaluate on CPU in the first iteration only (for
nlp_loaded = util.load_model_from_path(epoch_model_path) # timing) when GPU is enabled
for name, component in nlp_loaded.pipeline: if i == 0:
if hasattr(component, "cfg"): with use_ops("numpy"):
component.cfg["beam_width"] = beam_width nlp_loaded = util.load_model_from_path(epoch_model_path)
dev_dataset = list( for name, component in nlp_loaded.pipeline:
corpus.dev_dataset( if hasattr(component, "cfg"):
nlp_loaded, component.cfg["beam_width"] = beam_width
gold_preproc=gold_preproc, dev_dataset = list(
ignore_misaligned=True, corpus.dev_dataset(
nlp_loaded,
gold_preproc=gold_preproc,
ignore_misaligned=True,
)
) )
) start_time = timer()
start_time = timer() scorer = nlp_loaded.evaluate(dev_dataset, verbose=verbose)
scorer = nlp_loaded.evaluate(dev_dataset, verbose=verbose) end_time = timer()
end_time = timer() cpu_wps = nwords / (end_time - start_time)
cpu_wps = nwords / (end_time - start_time)
acc_loc = output_path / f"model{i}" / "accuracy.json" acc_loc = output_path / f"model{i}" / "accuracy.json"
srsly.write_json(acc_loc, scorer.scores) srsly.write_json(acc_loc, scorer.scores)
@@ -550,7 +553,7 @@ def train(
) )
break break
except Exception as e: except Exception as e:
msg.warn(f"Aborting and saving final best model. Encountered exception: {e}") msg.warn(f"Aborting and saving final best model. Encountered exception: {e}", exits=1)
finally: finally:
best_pipes = nlp.pipe_names best_pipes = nlp.pipe_names
if disabled_pipes: if disabled_pipes: