mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 21:54:54 +03:00
Change default value for n_trials. Log table iteratively.
This commit is contained in:
parent
67596fc58a
commit
9d947a42ce
|
@ -13,7 +13,7 @@ from ._util import app, Arg, Opt, import_code, setup_gpu
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
_DEFAULTS = {
|
_DEFAULTS = {
|
||||||
"n_trials": 10,
|
"n_trials": 11,
|
||||||
"use_gpu": -1,
|
"use_gpu": -1,
|
||||||
"gold_preproc": False,
|
"gold_preproc": False,
|
||||||
}
|
}
|
||||||
|
@ -109,7 +109,7 @@ def find_threshold(
|
||||||
except KeyError as err:
|
except KeyError as err:
|
||||||
wasabi.msg.fail(title=str(err), exits=1)
|
wasabi.msg.fail(title=str(err), exits=1)
|
||||||
if not hasattr(pipe, "scorer"):
|
if not hasattr(pipe, "scorer"):
|
||||||
raise AttributeError(Errors.E1047)
|
raise AttributeError(Errors.E1045)
|
||||||
|
|
||||||
if not silent:
|
if not silent:
|
||||||
wasabi.msg.info(
|
wasabi.msg.info(
|
||||||
|
@ -125,7 +125,7 @@ def find_threshold(
|
||||||
def set_nested_item(
|
def set_nested_item(
|
||||||
config: Dict[str, Any], keys: List[str], value: float
|
config: Dict[str, Any], keys: List[str], value: float
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
|
"""Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200.
|
||||||
config (Dict[str, Any]): Configuration dictionary.
|
config (Dict[str, Any]): Configuration dictionary.
|
||||||
keys (List[Any]): Path to value to set.
|
keys (List[Any]): Path to value to set.
|
||||||
value (float): Value to set.
|
value (float): Value to set.
|
||||||
|
@ -149,6 +149,8 @@ def find_threshold(
|
||||||
# Evaluate with varying threshold values.
|
# Evaluate with varying threshold values.
|
||||||
scores: Dict[float, float] = {}
|
scores: Dict[float, float] = {}
|
||||||
config_keys_full = ["components", pipe_name, *config_keys]
|
config_keys_full = ["components", pipe_name, *config_keys]
|
||||||
|
table_col_widths = (10, 10)
|
||||||
|
print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
|
||||||
for threshold in numpy.linspace(0, 1, n_trials):
|
for threshold in numpy.linspace(0, 1, n_trials):
|
||||||
# Reload pipeline with overrides specifying the new threshold.
|
# Reload pipeline with overrides specifying the new threshold.
|
||||||
nlp = util.load_model(
|
nlp = util.load_model(
|
||||||
|
@ -173,15 +175,17 @@ def find_threshold(
|
||||||
f"scores.",
|
f"scores.",
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
wasabi.row(
|
||||||
|
[round(threshold, 3), round(scores[threshold], 3)],
|
||||||
|
widths=table_col_widths,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
|
best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
|
||||||
if not silent:
|
if not silent:
|
||||||
print(
|
print(
|
||||||
f"Best threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}.",
|
f"\nBest threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}."
|
||||||
wasabi.tables.table(
|
|
||||||
data=[(threshold, score) for threshold, score in scores.items()],
|
|
||||||
header=["Threshold", f"{scores_key}"],
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return best_threshold, scores[best_threshold], scores
|
return best_threshold, scores[best_threshold], scores
|
||||||
|
|
Loading…
Reference in New Issue
Block a user