mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 13:44:55 +03:00
Incorporate feedback.
This commit is contained in:
parent
5500a58c00
commit
d080808bf7
|
@ -30,7 +30,7 @@ def find_threshold_cli(
|
||||||
data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
|
data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
|
||||||
pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
|
pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
|
||||||
threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
|
threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
|
||||||
scores_key: str = Arg(..., help="Name of score to metric to optimize"),
|
scores_key: str = Arg(..., help="Metric to optimize"),
|
||||||
n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
|
n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
|
||||||
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||||
use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
||||||
|
@ -39,12 +39,12 @@ def find_threshold_cli(
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric from CLI.
|
Runs prediction trials for models with varying tresholds to maximize the specified metric from CLI.
|
||||||
model (Path): Path to file with trained model.
|
model (Path): Path to file with trained model.
|
||||||
data_path (Path): Path to file with DocBin with docs to use for threshold search.
|
data_path (Path): Path to file with DocBin with docs to use for threshold search.
|
||||||
pipe_name (str): Name of pipe to examine thresholds for.
|
pipe_name (str): Name of pipe to examine thresholds for.
|
||||||
threshold_key (str): Key of threshold attribute in component's configuration.
|
threshold_key (str): Key of threshold attribute in component's configuration.
|
||||||
scores_key (str): Name of score to metric to optimize.
|
scores_key (str): Metric to optimize.
|
||||||
n_trials (int): Number of trials to determine optimal thresholds
|
n_trials (int): Number of trials to determine optimal thresholds
|
||||||
code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
|
code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
|
||||||
use_gpu (int): GPU ID or -1 for CPU.
|
use_gpu (int): GPU ID or -1 for CPU.
|
||||||
|
@ -82,7 +82,7 @@ def find_threshold(
|
||||||
silent: bool = True,
|
silent: bool = True,
|
||||||
) -> Tuple[float, float, Dict[float, float]]:
|
) -> Tuple[float, float, Dict[float, float]]:
|
||||||
"""
|
"""
|
||||||
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric.
|
Runs prediction trials for models with varying tresholds to maximize the specified metric.
|
||||||
model (Union[str, Path]): Path to file with trained model.
|
model (Union[str, Path]): Path to file with trained model.
|
||||||
data_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
|
data_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
|
||||||
pipe_name (str): Name of pipe to examine thresholds for.
|
pipe_name (str): Name of pipe to examine thresholds for.
|
||||||
|
@ -104,11 +104,9 @@ def find_threshold(
|
||||||
wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
|
wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
|
||||||
nlp = util.load_model(model)
|
nlp = util.load_model(model)
|
||||||
|
|
||||||
pipe: Optional[Any] = None
|
if pipe_name not in nlp.component_names:
|
||||||
try:
|
raise AttributeError(Errors.E001.format(name=pipe_name))
|
||||||
pipe = nlp.get_pipe(pipe_name)
|
pipe = nlp.get_pipe(pipe_name)
|
||||||
except KeyError as err:
|
|
||||||
wasabi.msg.fail(title=str(err), exits=1)
|
|
||||||
if not hasattr(pipe, "scorer"):
|
if not hasattr(pipe, "scorer"):
|
||||||
raise AttributeError(Errors.E1045)
|
raise AttributeError(Errors.E1045)
|
||||||
|
|
||||||
|
@ -141,14 +139,24 @@ def find_threshold(
|
||||||
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
|
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def filter_config(config: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
|
def filter_config(
|
||||||
|
config: Dict[str, Any], keys: List[str], full_key: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Filters provided config dictionary so that only the specified keys path remains.
|
"""Filters provided config dictionary so that only the specified keys path remains.
|
||||||
config (Dict[str, Any]): Configuration dictionary.
|
config (Dict[str, Any]): Configuration dictionary.
|
||||||
keys (List[Any]): Path to value to set.
|
keys (List[Any]): Path to value to set.
|
||||||
|
full_key (str): Full user-specified key.
|
||||||
RETURNS (Dict[str, Any]): Filtered dictionary.
|
RETURNS (Dict[str, Any]): Filtered dictionary.
|
||||||
"""
|
"""
|
||||||
|
if keys[0] not in config:
|
||||||
|
wasabi.msg.fail(
|
||||||
|
title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
|
||||||
|
text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
|
||||||
|
f"{config.keys()}",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
keys[0]: filter_config(config[keys[0]], keys[1:])
|
keys[0]: filter_config(config[keys[0]], keys[1:], full_key)
|
||||||
if len(keys) > 1
|
if len(keys) > 1
|
||||||
else config[keys[0]]
|
else config[keys[0]]
|
||||||
}
|
}
|
||||||
|
@ -164,7 +172,9 @@ def find_threshold(
|
||||||
nlp = util.load_model(
|
nlp = util.load_model(
|
||||||
model,
|
model,
|
||||||
config=set_nested_item(
|
config=set_nested_item(
|
||||||
filter_config(nlp.config, config_keys_full).copy(),
|
filter_config(
|
||||||
|
nlp.config, config_keys_full, ".".join(config_keys_full)
|
||||||
|
).copy(),
|
||||||
config_keys_full,
|
config_keys_full,
|
||||||
threshold,
|
threshold,
|
||||||
),
|
),
|
||||||
|
@ -176,7 +186,16 @@ def find_threshold(
|
||||||
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
|
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
|
||||||
)
|
)
|
||||||
|
|
||||||
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
|
eval_scores = nlp.evaluate(dev_dataset)
|
||||||
|
if scores_key not in eval_scores:
|
||||||
|
wasabi.msg.fail(
|
||||||
|
title=f"Failed to look up score `{scores_key}` in evaluation results.",
|
||||||
|
text=f"Make sure you specified the correct value for `scores_key` correctly. The following scores are "
|
||||||
|
f"available: {eval_scores.keys()}",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
|
scores[threshold] = eval_scores[scores_key]
|
||||||
|
|
||||||
if not isinstance(scores[threshold], (float, int)):
|
if not isinstance(scores[threshold], (float, int)):
|
||||||
wasabi.msg.fail(
|
wasabi.msg.fail(
|
||||||
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
|
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
|
||||||
|
|
Loading…
Reference in New Issue
Block a user