Merge pull request #6239 from explosion/feature/results-table

Make console logger table more compact
This commit is contained in:
Ines Montani 2020-10-11 13:10:37 +02:00 committed by GitHub
commit f4e4eeb141
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,11 +11,25 @@ if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
def setup_table(
*, cols: List[str], widths: List[int], max_width: int = 13
) -> Tuple[List[str], List[int], List[str]]:
final_cols = []
final_widths = []
for col, width in zip(cols, widths):
if len(col) > max_width:
col = col[: max_width - 3] + "..." # shorten column if too long
final_cols.append(col.upper())
final_widths.append(max(len(col), width))
return final_cols, final_widths, ["r" for _ in final_widths]
@registry.loggers("spacy.ConsoleLogger.v1") @registry.loggers("spacy.ConsoleLogger.v1")
def console_logger(progress_bar: bool = False): def console_logger(progress_bar: bool = False):
def setup_printer( def setup_printer(
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]: ) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
write = lambda text: stdout.write(f"{text}\n")
msg = Printer(no_print=True) msg = Printer(no_print=True)
# ensure that only trainable components are logged # ensure that only trainable components are logged
logged_pipes = [ logged_pipes = [
@ -26,15 +40,14 @@ def console_logger(progress_bar: bool = False):
eval_frequency = nlp.config["training"]["eval_frequency"] eval_frequency = nlp.config["training"]["eval_frequency"]
score_weights = nlp.config["training"]["score_weights"] score_weights = nlp.config["training"]["score_weights"]
score_cols = [col for col, value in score_weights.items() if value is not None] score_cols = [col for col, value in score_weights.items() if value is not None]
score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = [f"Loss {pipe}" for pipe in logged_pipes] loss_cols = [f"Loss {pipe}" for pipe in logged_pipes]
loss_widths = [max(len(col), 8) for col in loss_cols] spacing = 2
table_header = ["E", "#"] + loss_cols + score_cols + ["Score"] table_header, table_widths, table_aligns = setup_table(
table_header = [col.upper() for col in table_header] cols=["E", "#"] + loss_cols + score_cols + ["Score"],
table_widths = [3, 6] + loss_widths + score_widths + [6] widths=[3, 6] + [8 for _ in loss_cols] + [6 for _ in score_cols] + [6],
table_aligns = ["r" for _ in table_widths] )
stdout.write(msg.row(table_header, widths=table_widths) + "\n") write(msg.row(table_header, widths=table_widths, spacing=spacing))
stdout.write(msg.row(["-" * width for width in table_widths]) + "\n") write(msg.row(["-" * width for width in table_widths], spacing=spacing))
progress = None progress = None
def log_step(info: Optional[Dict[str, Any]]) -> None: def log_step(info: Optional[Dict[str, Any]]) -> None:
@ -70,7 +83,9 @@ def console_logger(progress_bar: bool = False):
) )
if progress is not None: if progress is not None:
progress.close() progress.close()
stdout.write(msg.row(data, widths=table_widths, aligns=table_aligns) + "\n") write(
msg.row(data, widths=table_widths, aligns=table_aligns, spacing=spacing)
)
if progress_bar: if progress_bar:
# Set disable=None, so that it disables on non-TTY # Set disable=None, so that it disables on non-TTY
progress = tqdm.tqdm( progress = tqdm.tqdm(