mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			234 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			234 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import functools
 | |
| import operator
 | |
| from pathlib import Path
 | |
| import logging
 | |
| from typing import Optional, Tuple, Any, Dict, List
 | |
| 
 | |
| import numpy
 | |
| import wasabi.tables
 | |
| 
 | |
| from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
 | |
| from ..errors import Errors
 | |
| from ..training import Corpus
 | |
| from ._util import app, Arg, Opt, import_code, setup_gpu
 | |
| from .. import util
 | |
| 
 | |
| _DEFAULTS = {
 | |
|     "n_trials": 11,
 | |
|     "use_gpu": -1,
 | |
|     "gold_preproc": False,
 | |
| }
 | |
| 
 | |
| 
 | |
@app.command(
    "find-threshold",
    context_settings={"allow_extra_args": False, "ignore_unknown_options": True},
)
def find_threshold_cli(
    # fmt: off
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
    pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
    threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
    scores_key: str = Arg(..., help="Metric to optimize"),
    # "--n-trials" is an alias matching the dash style of the other options;
    # "--n_trials" is kept (and listed first) for backward compatibility.
    n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "--n-trials", "-n", help="Number of trials to determine optimal thresholds"),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    # fmt: on
):
    """
    Runs prediction trials for a trained model with varying thresholds to maximize
    the specified metric. The search space for the threshold is traversed linearly
    from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout`
    (the corresponding API call to `spacy.cli.find_threshold.find_threshold()`
    returns all results).

    This is applicable only for components whose predictions are influenced by
    thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note
    that the full path to the corresponding threshold attribute in the config has to
    be provided.

    DOCS: https://spacy.io/api/cli#find-threshold
    """
    # The verbosity flag only controls the library logger level here.
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
    # Import user code first so registered functions exist before the model loads.
    import_code(code_path)
    # The CLI always reports progress, hence silent=False.
    find_threshold(
        model=model,
        data_path=data_path,
        pipe_name=pipe_name,
        threshold_key=threshold_key,
        scores_key=scores_key,
        n_trials=n_trials,
        use_gpu=use_gpu,
        gold_preproc=gold_preproc,
        silent=False,
    )
 | |
| 
 | |
| 
 | |
def _set_nested_item(
    config: Dict[str, Any], keys: List[str], value: float
) -> Dict[str, Any]:
    """Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200.
    config (Dict[str, Any]): Configuration dictionary. It is modified in place.
    keys (List[str]): Path to the value to set. All keys but the last must exist.
    value (float): Value to set.
    RETURNS (Dict[str, Any]): The same (updated) dictionary, for convenient chaining.
    """
    functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
    return config


def _filter_config(
    config: Dict[str, Any], keys: List[str], full_key: str
) -> Dict[str, Any]:
    """Filters provided config dictionary so that only the specified keys path remains.
    New dicts are built along the path, so the result can be modified without
    touching `config`; only the final leaf value is shared with the original.
    config (Dict[str, Any]): Configuration dictionary.
    keys (List[str]): Path to keep.
    full_key (str): Full user-specified key, used in error messages.
    RETURNS (Dict[str, Any]): Filtered dictionary.
    """
    # A path that runs past a non-section value (e.g. a float) is reported the
    # same way as a missing key instead of crashing with a TypeError.
    is_section = isinstance(config, dict)
    if not is_section or keys[0] not in config:
        available = list(config.keys()) if is_section else []
        wasabi.msg.fail(
            title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
            text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
            f"{available}",
            exits=1,
        )
    return {
        keys[0]: _filter_config(config[keys[0]], keys[1:], full_key)
        if len(keys) > 1
        else config[keys[0]]
    }


def find_threshold(
    model: str,
    data_path: Path,
    pipe_name: str,
    threshold_key: str,
    scores_key: str,
    *,
    # `type: ignore`: the mixed-type _DEFAULTS dict gives mypy no precise value types.
    n_trials: int = _DEFAULTS["n_trials"],  # type: ignore
    use_gpu: int = _DEFAULTS["use_gpu"],  # type: ignore
    gold_preproc: bool = _DEFAULTS["gold_preproc"],  # type: ignore
    silent: bool = True,
) -> Tuple[float, float, Dict[float, float]]:
    """
    Runs prediction trials for models with varying thresholds to maximize the specified metric.
    model (Union[str, Path]): Pipeline to evaluate. Can be a package or a path to a data directory.
    data_path (Path): Path to file with DocBin with docs to use for threshold search.
    pipe_name (str): Name of pipe to examine thresholds for.
    threshold_key (str): Key of threshold attribute in component's configuration.
    scores_key (str): Name of score metric to optimize.
    n_trials (int): Number of trials to determine optimal thresholds. Must be >= 1.
    use_gpu (int): GPU ID or -1 for CPU.
    gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
        tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
        to train/test skew.
    silent (bool): Whether to print non-error-related output to stdout.
    RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
        evaluated thresholds.
    RAISES ValueError: If `n_trials` is smaller than 1.
    """
    # Fail fast, before loading anything: with no trials there is nothing to maximize.
    if n_trials < 1:
        raise ValueError(f"`n_trials` must be at least 1, got {n_trials}.")

    setup_gpu(use_gpu, silent=silent)
    data_path = util.ensure_path(data_path)
    if not data_path.exists():
        wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
    nlp = util.load_model(model)

    if pipe_name not in nlp.component_names:
        raise AttributeError(
            Errors.E001.format(name=pipe_name, opts=nlp.component_names)
        )
    pipe = nlp.get_pipe(pipe_name)
    if not hasattr(pipe, "scorer"):
        raise AttributeError(Errors.E1045)

    # Exact type check on purpose (not isinstance): only plain `textcat` ignores
    # the threshold; subclasses (e.g. MultiLabel_TextCategorizer, if it derives
    # from TextCategorizer) must not trigger this warning.
    if type(pipe) is TextCategorizer:
        wasabi.msg.warn(
            "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
            "exclusive classes. All thresholds will yield the same results."
        )

    if not silent:
        wasabi.msg.info(
            title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
            f"trials."
        )

    # Load evaluation corpus.
    corpus = Corpus(data_path, gold_preproc=gold_preproc)
    dev_dataset = list(corpus(nlp))
    config_keys = threshold_key.split(".")
    config_keys_full = ["components", pipe_name, *config_keys]
    full_key = ".".join(config_keys_full)

    # Evaluate with varying threshold values.
    scores: Dict[float, float] = {}
    table_col_widths = (10, 10)
    # .tolist() yields builtin floats, matching the annotated return type.
    thresholds: List[float] = numpy.linspace(0, 1, n_trials).tolist()
    if not silent:
        print(
            wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths)
        )
    for threshold in thresholds:
        # Reload pipeline with overrides specifying the new threshold. The filtered
        # config is freshly built, so setting the leaf leaves nlp.config untouched.
        nlp = util.load_model(
            model,
            config=_set_nested_item(
                _filter_config(nlp.config, config_keys_full, full_key),
                config_keys_full,
                threshold,
            ),
        )
        # Also set the threshold on the *reloaded* component's cfg, rather than
        # mutating the stale `pipe` from the initial load.
        reloaded_pipe = nlp.get_pipe(pipe_name)
        if hasattr(reloaded_pipe, "cfg"):
            setattr(
                reloaded_pipe,
                "cfg",
                _set_nested_item(getattr(reloaded_pipe, "cfg"), config_keys, threshold),
            )

        eval_scores = nlp.evaluate(dev_dataset)
        if scores_key not in eval_scores:
            wasabi.msg.fail(
                title=f"Failed to look up score `{scores_key}` in evaluation results.",
                text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
                f"available: {list(eval_scores.keys())}",
                exits=1,
            )
        score = eval_scores[scores_key]
        if not isinstance(score, (float, int)):
            wasabi.msg.fail(
                f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
                f"scores.",
                exits=1,
            )
        scores[threshold] = score
        if not silent:
            print(
                wasabi.tables.row(
                    [round(threshold, 3), round(score, 3)],
                    widths=table_col_widths,
                )
            )

    # Ties resolve to the lowest threshold, since dict order follows the sweep.
    best_threshold = max(scores, key=scores.__getitem__)

    # Identical scores across several trials suggest the threshold has no effect.
    # A single trial trivially has "identical" scores, so it is not warned about.
    if len(scores) > 1 and len(set(scores.values())) == 1:
        wasabi.msg.warn(
            title="All scores are identical. Verify that all settings are correct.",
            text=""
            if (
                not isinstance(pipe, MultiLabel_TextCategorizer)
                or scores_key in ("cats_macro_f", "cats_micro_f")
            )
            else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
        )
    elif not silent:
        print(
            f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}."
        )

    return best_threshold, scores[best_threshold], scores
 |