debug data Spancat Table Improvements (#11504)

* update

* fix format function

* pull out _format_number

* format with black
This commit is contained in:
Peter Baumgartner 2022-09-28 11:16:05 -04:00 committed by GitHub
parent aea16719be
commit e794d4ae39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 5 deletions

View File

@ -573,3 +573,12 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
local_msg.info("Using CPU")
if gpu_is_available():
local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")
def _format_number(number: Union[int, float], ndigits: int = 2) -> str:
"""Formats a number (float or int) rounding to `ndigits`, without truncating trailing 0s,
as happens with `round(number, ndigits)`"""
if isinstance(number, float):
return f"{number:.{ndigits}f}"
else:
return str(number)

View File

@ -9,7 +9,7 @@ import typer
import math
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
from ._util import import_code, debug_cli, _format_number
from ..training import Example, remove_bilu_prefix
from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining
@ -989,7 +989,8 @@ def _get_kl_divergence(p: Counter, q: Counter) -> float:
def _format_span_row(span_data: List[Dict], labels: List[str]) -> List[Any]:
"""Compile into one list for easier reporting"""
d = {
label: [label] + list(round(d[label], 2) for d in span_data) for label in labels
label: [label] + list(_format_number(d[label]) for d in span_data)
for label in labels
}
return list(d.values())
@ -1004,6 +1005,10 @@ def _get_span_characteristics(
label: _gmean(l)
for label, l in compiled_gold["spans_length"][spans_key].items()
}
spans_per_type = {
label: len(spans)
for label, spans in compiled_gold["spans_per_type"][spans_key].items()
}
min_lengths = [min(l) for l in compiled_gold["spans_length"][spans_key].values()]
max_lengths = [max(l) for l in compiled_gold["spans_length"][spans_key].values()]
@ -1031,6 +1036,7 @@ def _get_span_characteristics(
return {
"sd": span_distinctiveness,
"bd": sb_distinctiveness,
"spans_per_type": spans_per_type,
"lengths": span_length,
"min_length": min(min_lengths),
"max_length": max(max_lengths),
@ -1045,12 +1051,15 @@ def _get_span_characteristics(
def _print_span_characteristics(span_characteristics: Dict[str, Any]):
"""Print all span characteristics into a table"""
headers = ("Span Type", "Length", "SD", "BD")
headers = ("Span Type", "Length", "SD", "BD", "N")
# Wasabi has this at 30 by default, but we might have some long labels
max_col = max(30, max(len(label) for label in span_characteristics["labels"]))
# Prepare table data with all span characteristics
table_data = [
span_characteristics["lengths"],
span_characteristics["sd"],
span_characteristics["bd"],
span_characteristics["spans_per_type"],
]
table = _format_span_row(
span_data=table_data, labels=span_characteristics["labels"]
@ -1061,8 +1070,18 @@ def _print_span_characteristics(span_characteristics: Dict[str, Any]):
span_characteristics["avg_sd"],
span_characteristics["avg_bd"],
]
footer = ["Wgt. Average"] + [str(round(f, 2)) for f in footer_data]
msg.table(table, footer=footer, header=headers, divider=True)
footer = (
["Wgt. Average"] + ["{:.2f}".format(round(f, 2)) for f in footer_data] + ["-"]
)
msg.table(
table,
footer=footer,
header=headers,
divider=True,
aligns=["l"] + ["r"] * (len(footer_data) + 1),
max_col=max_col,
)
def _get_spans_length_freq_dist(