Merge pull request #6123 from svlandeg/fix/frozen

This commit is contained in:
Ines Montani 2020-09-23 13:04:44 +02:00 committed by GitHub
commit af92d79e92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 4 deletions

View File

@ -152,7 +152,8 @@ def train(
exclude=frozen_components, exclude=frozen_components,
) )
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
print_row, finalize_logger = train_logger(nlp) with nlp.select_pipes(disable=frozen_components):
print_row, finalize_logger = train_logger(nlp)
try: try:
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
@ -163,7 +164,8 @@ def train(
progress.close() progress.close()
print_row(info) print_row(info)
if is_best_checkpoint and output_path is not None: if is_best_checkpoint and output_path is not None:
update_meta(T_cfg, nlp, info) with nlp.select_pipes(disable=frozen_components):
update_meta(T_cfg, nlp, info)
with nlp.use_params(optimizer.averages): with nlp.use_params(optimizer.averages):
nlp.to_disk(output_path / "model-best") nlp.to_disk(output_path / "model-best")
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False) progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)

View File

@ -11,9 +11,11 @@ def console_logger():
def setup_printer( def setup_printer(
nlp: "Language", nlp: "Language",
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
# we assume here that only components are enabled that should be trained & logged
logged_pipes = nlp.pipe_names
score_cols = list(nlp.config["training"]["score_weights"]) score_cols = list(nlp.config["training"]["score_weights"])
score_widths = [max(len(col), 6) for col in score_cols] score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names] loss_cols = [f"Loss {pipe}" for pipe in logged_pipes]
loss_widths = [max(len(col), 8) for col in loss_cols] loss_widths = [max(len(col), 8) for col in loss_cols]
table_header = ["E", "#"] + loss_cols + score_cols + ["Score"] table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
table_header = [col.upper() for col in table_header] table_header = [col.upper() for col in table_header]
@ -26,7 +28,7 @@ def console_logger():
try: try:
losses = [ losses = [
"{0:.2f}".format(float(info["losses"][pipe_name])) "{0:.2f}".format(float(info["losses"][pipe_name]))
for pipe_name in nlp.pipe_names for pipe_name in logged_pipes
] ]
except KeyError as e: except KeyError as e:
raise KeyError( raise KeyError(