Incorporate feedback.

This commit is contained in:
Raphael Mitsch 2022-11-17 11:16:08 +01:00
parent 5500a58c00
commit d080808bf7

View File

@ -30,7 +30,7 @@ def find_threshold_cli(
data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True), data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"), pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"), threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
scores_key: str = Arg(..., help="Name of score to metric to optimize"), scores_key: str = Arg(..., help="Metric to optimize"),
n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"), n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"), code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"), use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
@ -39,12 +39,12 @@ def find_threshold_cli(
# fmt: on # fmt: on
): ):
""" """
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric from CLI. Runs prediction trials for models with varying tresholds to maximize the specified metric from CLI.
model (Path): Path to file with trained model. model (Path): Path to file with trained model.
data_path (Path): Path to file with DocBin with docs to use for threshold search. data_path (Path): Path to file with DocBin with docs to use for threshold search.
pipe_name (str): Name of pipe to examine thresholds for. pipe_name (str): Name of pipe to examine thresholds for.
threshold_key (str): Key of threshold attribute in component's configuration. threshold_key (str): Key of threshold attribute in component's configuration.
scores_key (str): Name of score to metric to optimize. scores_key (str): Metric to optimize.
n_trials (int): Number of trials to determine optimal thresholds n_trials (int): Number of trials to determine optimal thresholds
code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported. code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
use_gpu (int): GPU ID or -1 for CPU. use_gpu (int): GPU ID or -1 for CPU.
@ -82,7 +82,7 @@ def find_threshold(
silent: bool = True, silent: bool = True,
) -> Tuple[float, float, Dict[float, float]]: ) -> Tuple[float, float, Dict[float, float]]:
""" """
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric. Runs prediction trials for models with varying tresholds to maximize the specified metric.
model (Union[str, Path]): Path to file with trained model. model (Union[str, Path]): Path to file with trained model.
data_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search. data_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
pipe_name (str): Name of pipe to examine thresholds for. pipe_name (str): Name of pipe to examine thresholds for.
@ -104,11 +104,9 @@ def find_threshold(
wasabi.msg.fail("Evaluation data not found", data_path, exits=1) wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
nlp = util.load_model(model) nlp = util.load_model(model)
pipe: Optional[Any] = None if pipe_name not in nlp.component_names:
try: raise AttributeError(Errors.E001.format(name=pipe_name))
pipe = nlp.get_pipe(pipe_name) pipe = nlp.get_pipe(pipe_name)
except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1)
if not hasattr(pipe, "scorer"): if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045) raise AttributeError(Errors.E1045)
@ -141,14 +139,24 @@ def find_threshold(
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
return config return config
def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: def filter_config(
config: Dict[str, Any], keys: List[str], full_key: str
) -> Dict[str, Any]:
"""Filters provided config dictionary so that only the specified keys path remains. """Filters provided config dictionary so that only the specified keys path remains.
config (Dict[str, Any]): Configuration dictionary. config (Dict[str, Any]): Configuration dictionary.
keys (List[Any]): Path to value to set. keys (List[Any]): Path to value to set.
full_key (str): Full user-specified key.
RETURNS (Dict[str, Any]): Filtered dictionary. RETURNS (Dict[str, Any]): Filtered dictionary.
""" """
if keys[0] not in config:
wasabi.msg.fail(
title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
f"{config.keys()}",
exits=1,
)
return { return {
keys[0]: filter_config(config[keys[0]], keys[1:]) keys[0]: filter_config(config[keys[0]], keys[1:], full_key)
if len(keys) > 1 if len(keys) > 1
else config[keys[0]] else config[keys[0]]
} }
@ -164,7 +172,9 @@ def find_threshold(
nlp = util.load_model( nlp = util.load_model(
model, model,
config=set_nested_item( config=set_nested_item(
filter_config(nlp.config, config_keys_full).copy(), filter_config(
nlp.config, config_keys_full, ".".join(config_keys_full)
).copy(),
config_keys_full, config_keys_full,
threshold, threshold,
), ),
@ -176,7 +186,16 @@ def find_threshold(
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold), set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
) )
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key] eval_scores = nlp.evaluate(dev_dataset)
if scores_key not in eval_scores:
wasabi.msg.fail(
title=f"Failed to look up score `{scores_key}` in evaluation results.",
text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are "
f"available: {eval_scores.keys()}",
exits=1,
)
scores[threshold] = eval_scores[scores_key]
if not isinstance(scores[threshold], (float, int)): if not isinstance(scores[threshold], (float, int)):
wasabi.msg.fail( wasabi.msg.fail(
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric " f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "