diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index ae43b991b..897964a88 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -573,3 +573,12 @@ def setup_gpu(use_gpu: int, silent=None) -> None: local_msg.info("Using CPU") if gpu_is_available(): local_msg.info("To switch to GPU 0, use the option: --gpu-id 0") + + +def _format_number(number: Union[int, float], ndigits: int = 2) -> str: + """Formats a number (float or int) rounding to `ndigits`, without truncating trailing 0s, + as happens with `round(number, ndigits)`""" + if isinstance(number, float): + return f"{number:.{ndigits}f}" + else: + return str(number) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index bd05471b1..963d5b926 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -9,7 +9,7 @@ import typer import math from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides -from ._util import import_code, debug_cli +from ._util import import_code, debug_cli, _format_number from ..training import Example, remove_bilu_prefix from ..training.initialize import get_sourced_components from ..schemas import ConfigSchemaTraining @@ -989,7 +989,8 @@ def _get_kl_divergence(p: Counter, q: Counter) -> float: def _format_span_row(span_data: List[Dict], labels: List[str]) -> List[Any]: """Compile into one list for easier reporting""" d = { - label: [label] + list(round(d[label], 2) for d in span_data) for label in labels + label: [label] + list(_format_number(d[label]) for d in span_data) + for label in labels } return list(d.values()) @@ -1004,6 +1005,10 @@ def _get_span_characteristics( label: _gmean(l) for label, l in compiled_gold["spans_length"][spans_key].items() } + spans_per_type = { + label: len(spans) + for label, spans in compiled_gold["spans_per_type"][spans_key].items() + } min_lengths = [min(l) for l in compiled_gold["spans_length"][spans_key].values()] max_lengths = [max(l) for l in compiled_gold["spans_length"][spans_key].values()] @@ -1031,6 +1036,7 @@ def _get_span_characteristics( return { "sd": span_distinctiveness, "bd": sb_distinctiveness, + "spans_per_type": spans_per_type, "lengths": span_length, "min_length": min(min_lengths), "max_length": max(max_lengths), @@ -1045,12 +1051,15 @@ def _get_span_characteristics( def _print_span_characteristics(span_characteristics: Dict[str, Any]): """Print all span characteristics into a table""" - headers = ("Span Type", "Length", "SD", "BD") + headers = ("Span Type", "Length", "SD", "BD", "N") + # Wasabi has this at 30 by default, but we might have some long labels + max_col = max(30, max(len(label) for label in span_characteristics["labels"])) # Prepare table data with all span characteristics table_data = [ span_characteristics["lengths"], span_characteristics["sd"], span_characteristics["bd"], + span_characteristics["spans_per_type"], ] table = _format_span_row( span_data=table_data, labels=span_characteristics["labels"] @@ -1061,8 +1070,18 @@ def _print_span_characteristics(span_characteristics: Dict[str, Any]): span_characteristics["avg_sd"], span_characteristics["avg_bd"], ] - footer = ["Wgt. Average"] + [str(round(f, 2)) for f in footer_data] - msg.table(table, footer=footer, header=headers, divider=True) + + footer = ( + ["Wgt. Average"] + ["{:.2f}".format(round(f, 2)) for f in footer_data] + ["-"] + ) + msg.table( + table, + footer=footer, + header=headers, + divider=True, + aligns=["l"] + ["r"] * (len(footer_data) + 1), + max_col=max_col, + ) def _get_spans_length_freq_dist(