Change default value for n_trials. Log table iteratively.

This commit is contained in:
Raphael Mitsch 2022-10-24 11:25:51 +02:00
parent 67596fc58a
commit 9d947a42ce

View File

@ -13,7 +13,7 @@ from ._util import app, Arg, Opt, import_code, setup_gpu
from .. import util
_DEFAULTS = {
"n_trials": 10,
"n_trials": 11,
"use_gpu": -1,
"gold_preproc": False,
}
@ -109,7 +109,7 @@ def find_threshold(
except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1)
if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1047)
raise AttributeError(Errors.E1045)
if not silent:
wasabi.msg.info(
@ -125,7 +125,7 @@ def find_threshold(
def set_nested_item(
config: Dict[str, Any], keys: List[str], value: float
) -> Dict[str, Any]:
"""Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
"""Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200.
config (Dict[str, Any]): Configuration dictionary.
keys (List[str]): Path to value to set.
value (float): Value to set.
@ -149,6 +149,8 @@ def find_threshold(
# Evaluate with varying threshold values.
scores: Dict[float, float] = {}
config_keys_full = ["components", pipe_name, *config_keys]
table_col_widths = (10, 10)
print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
for threshold in numpy.linspace(0, 1, n_trials):
# Reload pipeline with overrides specifying the new threshold.
nlp = util.load_model(
@ -173,15 +175,17 @@ def find_threshold(
f"scores.",
exits=1,
)
print(
wasabi.row(
[round(threshold, 3), round(scores[threshold], 3)],
widths=table_col_widths,
)
)
best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
if not silent:
print(
f"Best threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}.",
wasabi.tables.table(
data=[(threshold, score) for threshold, score in scores.items()],
header=["Threshold", f"{scores_key}"],
),
f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
)
return best_threshold, scores[best_threshold], scores