diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index ce76ef9a9..aab2c8d12 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -27,6 +27,7 @@ from .project.dvc import project_update_dvc # noqa: F401 from .project.push import project_push # noqa: F401 from .project.pull import project_pull # noqa: F401 from .project.document import project_document # noqa: F401 +from .find_threshold import find_threshold # noqa: F401 @app.command("link", no_args_is_help=True, deprecated=True, hidden=True) diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py new file mode 100644 index 000000000..efa664832 --- /dev/null +++ b/spacy/cli/find_threshold.py @@ -0,0 +1,233 @@ +import functools +import operator +from pathlib import Path +import logging +from typing import Optional, Tuple, Any, Dict, List + +import numpy +import wasabi.tables + +from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer +from ..errors import Errors +from ..training import Corpus +from ._util import app, Arg, Opt, import_code, setup_gpu +from .. import util + +_DEFAULTS = { + "n_trials": 11, + "use_gpu": -1, + "gold_preproc": False, +} + + +@app.command( + "find-threshold", + context_settings={"allow_extra_args": False, "ignore_unknown_options": True}, +) +def find_threshold_cli( + # fmt: off + model: str = Arg(..., help="Model name or path"), + data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True), + pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"), + threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"), + scores_key: str = Arg(..., help="Metric to optimize"), + n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"), + code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"), + use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"), + gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"), + verbose: bool = Opt(False, "--silent", "-V", "-VV", help="Display more information for debugging purposes"), + # fmt: on +): + """ + Runs prediction trials for a trained model with varying tresholds to maximize + the specified metric. The search space for the threshold is traversed linearly + from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout` + (the corresponding API call to `spacy.cli.find_threshold.find_threshold()` + returns all results). + + This is applicable only for components whose predictions are influenced by + thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note + that the full path to the corresponding threshold attribute in the config has to + be provided. + + DOCS: https://spacy.io/api/cli#find-threshold + """ + + util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) + import_code(code_path) + find_threshold( + model=model, + data_path=data_path, + pipe_name=pipe_name, + threshold_key=threshold_key, + scores_key=scores_key, + n_trials=n_trials, + use_gpu=use_gpu, + gold_preproc=gold_preproc, + silent=False, + ) + + +def find_threshold( + model: str, + data_path: Path, + pipe_name: str, + threshold_key: str, + scores_key: str, + *, + n_trials: int = _DEFAULTS["n_trials"], # type: ignore + use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore + gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore + silent: bool = True, +) -> Tuple[float, float, Dict[float, float]]: + """ + Runs prediction trials for models with varying tresholds to maximize the specified metric. + model (Union[str, Path]): Pipeline to evaluate. Can be a package or a path to a data directory. + data_path (Path): Path to file with DocBin with docs to use for threshold search. + pipe_name (str): Name of pipe to examine thresholds for. + threshold_key (str): Key of threshold attribute in component's configuration. + scores_key (str): Name of score to metric to optimize. + n_trials (int): Number of trials to determine optimal thresholds. + use_gpu (int): GPU ID or -1 for CPU. + gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the + tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due + to train/test skew. + silent (bool): Whether to print non-error-related output to stdout. + RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all + evaluated thresholds. + """ + + setup_gpu(use_gpu, silent=silent) + data_path = util.ensure_path(data_path) + if not data_path.exists(): + wasabi.msg.fail("Evaluation data not found", data_path, exits=1) + nlp = util.load_model(model) + + if pipe_name not in nlp.component_names: + raise AttributeError( + Errors.E001.format(name=pipe_name, opts=nlp.component_names) + ) + pipe = nlp.get_pipe(pipe_name) + if not hasattr(pipe, "scorer"): + raise AttributeError(Errors.E1045) + + if type(pipe) == TextCategorizer: + wasabi.msg.warn( + "The `textcat` component doesn't use a threshold as it's not applicable to the concept of " + "exclusive classes. All thresholds will yield the same results." + ) + + if not silent: + wasabi.msg.info( + title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} " + f"trials." + ) + + # Load evaluation corpus. + corpus = Corpus(data_path, gold_preproc=gold_preproc) + dev_dataset = list(corpus(nlp)) + config_keys = threshold_key.split(".") + + def set_nested_item( + config: Dict[str, Any], keys: List[str], value: float + ) -> Dict[str, Any]: + """Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200. + config (Dict[str, Any]): Configuration dictionary. + keys (List[Any]): Path to value to set. + value (float): Value to set. + RETURNS (Dict[str, Any]): Updated dictionary. + """ + functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value + return config + + def filter_config( + config: Dict[str, Any], keys: List[str], full_key: str + ) -> Dict[str, Any]: + """Filters provided config dictionary so that only the specified keys path remains. + config (Dict[str, Any]): Configuration dictionary. + keys (List[Any]): Path to value to set. + full_key (str): Full user-specified key. + RETURNS (Dict[str, Any]): Filtered dictionary. + """ + if keys[0] not in config: + wasabi.msg.fail( + title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.", + text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: " + f"{list(config.keys())}", + exits=1, + ) + return { + keys[0]: filter_config(config[keys[0]], keys[1:], full_key) + if len(keys) > 1 + else config[keys[0]] + } + + # Evaluate with varying threshold values. + scores: Dict[float, float] = {} + config_keys_full = ["components", pipe_name, *config_keys] + table_col_widths = (10, 10) + thresholds = numpy.linspace(0, 1, n_trials) + print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths)) + for threshold in thresholds: + # Reload pipeline with overrides specifying the new threshold. + nlp = util.load_model( + model, + config=set_nested_item( + filter_config( + nlp.config, config_keys_full, ".".join(config_keys_full) + ).copy(), + config_keys_full, + threshold, + ), + ) + if hasattr(pipe, "cfg"): + setattr( + nlp.get_pipe(pipe_name), + "cfg", + set_nested_item(getattr(pipe, "cfg"), config_keys, threshold), + ) + + eval_scores = nlp.evaluate(dev_dataset) + if scores_key not in eval_scores: + wasabi.msg.fail( + title=f"Failed to look up score `{scores_key}` in evaluation results.", + text=f"Make sure you specified the correct value for `scores_key`. The following scores are " + f"available: {list(eval_scores.keys())}", + exits=1, + ) + scores[threshold] = eval_scores[scores_key] + + if not isinstance(scores[threshold], (float, int)): + wasabi.msg.fail( + f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric " + f"scores.", + exits=1, + ) + print( + wasabi.row( + [round(threshold, 3), round(scores[threshold], 3)], + widths=table_col_widths, + ) + ) + + best_threshold = max(scores.keys(), key=(lambda key: scores[key])) + + # If all scores are identical, emit warning. + if len(set(scores.values())) == 1: + wasabi.msg.warn( + title="All scores are identical. Verify that all settings are correct.", + text="" + if ( + not isinstance(pipe, MultiLabel_TextCategorizer) + or scores_key in ("cats_macro_f", "cats_micro_f") + ) + else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.", + ) + + else: + if not silent: + print( + f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}." + ) + + return best_threshold, scores[best_threshold], scores diff --git a/spacy/errors.py b/spacy/errors.py index 1d29f0e17..a8de5fb90 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -956,6 +956,7 @@ class Errors(metaclass=ErrorsWithCodes): "sure it's overwritten on the subclass.") E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default " "knowledge base, use `InMemoryLookupKB`.") + E1047 = ("`find_threshold()` only supports components with a `scorer` attribute.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 956bbb72c..0a84c72fd 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -1,7 +1,7 @@ -from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast +from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Optimizer -from thinc.types import Ragged, Ints2d, Floats2d, Ints1d +from thinc.types import Ragged, Ints2d, Floats2d import numpy diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 8225e14f1..1c4d0c98f 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -1,9 +1,10 @@ import os import math +from collections import Counter +from typing import Tuple, List, Dict, Any import pkg_resources -from random import sample -from typing import Counter +import numpy import pytest import srsly from click import NoSuchOption @@ -28,11 +29,12 @@ from spacy.cli.package import get_third_party_dependencies from spacy.cli.package import _is_permitted_package_name from spacy.cli.project.run import _check_requirements from spacy.cli.validate import get_model_pkgs +from spacy.cli.find_threshold import find_threshold from spacy.lang.en import English from spacy.lang.nl import Dutch from spacy.language import Language from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate -from spacy.tokens import Doc +from spacy.tokens import Doc, DocBin from spacy.tokens.span import Span from spacy.training import Example, docs_to_json, offsets_to_biluo_tags from spacy.training.converters import conll_ner_to_docs, conllu_to_docs @@ -859,6 +861,122 @@ def test_span_length_freq_dist_output_must_be_correct(): assert list(span_freqs.keys()) == [3, 1, 4, 5, 2] +def test_cli_find_threshold(capsys): + thresholds = numpy.linspace(0, 1, 10) + + def make_examples(nlp: Language) -> List[Example]: + docs: List[Example] = [] + + for t in [ + ( + "I am angry and confused in the Bank of America.", + { + "cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}, + "spans": {"sc": [(31, 46, "ORG")]}, + }, + ), + ( + "I am confused but happy in New York.", + { + "cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}, + "spans": {"sc": [(27, 35, "GPE")]}, + }, + ), + ]: + doc = nlp.make_doc(t[0]) + docs.append(Example.from_dict(doc, t[1])) + + return docs + + def init_nlp( + components: Tuple[Tuple[str, Dict[str, Any]], ...] = () + ) -> Tuple[Language, List[Example]]: + new_nlp = English() + new_nlp.add_pipe( # type: ignore + factory_name="textcat_multilabel", + name="tc_multi", + config={"threshold": 0.9}, + ) + + # Append additional components to pipeline. + for cfn, comp_config in components: + new_nlp.add_pipe(cfn, config=comp_config) + + new_examples = make_examples(new_nlp) + new_nlp.initialize(get_examples=lambda: new_examples) + for i in range(5): + new_nlp.update(new_examples) + + return new_nlp, new_examples + + with make_tempdir() as docs_dir: + # Check whether find_threshold() identifies lowest threshold above 0 as (first) ideal threshold, as this matches + # the current model behavior with the examples above. This can break once the model behavior changes and serves + # mostly as a smoke test. + nlp, examples = init_nlp() + DocBin(docs=[example.reference for example in examples]).to_disk( + docs_dir / "docs.spacy" + ) + with make_tempdir() as nlp_dir: + nlp.to_disk(nlp_dir) + res = find_threshold( + model=nlp_dir, + data_path=docs_dir / "docs.spacy", + pipe_name="tc_multi", + threshold_key="threshold", + scores_key="cats_macro_f", + silent=True, + ) + assert res[0] != thresholds[0] + assert thresholds[0] < res[0] < thresholds[9] + assert res[1] == 1.0 + assert res[2][1.0] == 0.0 + + # Test with spancat. + nlp, _ = init_nlp((("spancat", {}),)) + with make_tempdir() as nlp_dir: + nlp.to_disk(nlp_dir) + res = find_threshold( + model=nlp_dir, + data_path=docs_dir / "docs.spacy", + pipe_name="spancat", + threshold_key="threshold", + scores_key="spans_sc_f", + silent=True, + ) + assert res[0] != thresholds[0] + assert thresholds[0] < res[0] < thresholds[8] + assert res[1] >= 0.6 + assert res[2][1.0] == 0.0 + + # Having multiple textcat_multilabel components should work, since the name has to be specified. + nlp, _ = init_nlp((("textcat_multilabel", {}),)) + with make_tempdir() as nlp_dir: + nlp.to_disk(nlp_dir) + assert find_threshold( + model=nlp_dir, + data_path=docs_dir / "docs.spacy", + pipe_name="tc_multi", + threshold_key="threshold", + scores_key="cats_macro_f", + silent=True, + ) + + # Specifying the name of an non-existing pipe should fail. + nlp, _ = init_nlp() + with make_tempdir() as nlp_dir: + nlp.to_disk(nlp_dir) + with pytest.raises(AttributeError): + find_threshold( + model=nlp_dir, + data_path=docs_dir / "docs.spacy", + pipe_name="_", + threshold_key="threshold", + scores_key="cats_macro_f", + silent=True, + ) + + @pytest.mark.parametrize( "reqs,output", [ diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md index 6e581b903..b42ba8a4f 100644 --- a/website/docs/api/cli.md +++ b/website/docs/api/cli.md @@ -12,6 +12,7 @@ menu: - ['train', 'train'] - ['pretrain', 'pretrain'] - ['evaluate', 'evaluate'] + - ['find-threshold', 'find-threshold'] - ['assemble', 'assemble'] - ['package', 'package'] - ['project', 'project'] @@ -1161,6 +1162,46 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-prepr | `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ | | **CREATES** | Training results and optional metrics and visualizations. | +## find-threshold {#find-threshold new="3.5" tag="command"} + +Runs prediction trials for a trained model with varying tresholds to maximize +the specified metric. The search space for the threshold is traversed linearly +from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout` +(the corresponding API call to `spacy.cli.find_threshold.find_threshold()` +returns all results). + +This is applicable only for components whose predictions are influenced by +thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note +that the full path to the corresponding threshold attribute in the config has to +be provided. + +> #### Examples +> +> ```cli +> # For textcat_multilabel: +> $ python -m spacy find-threshold my_nlp data.spacy textcat_multilabel threshold cats_macro_f +> ``` +> +> ```cli +> # For spancat: +> $ python -m spacy find-threshold my_nlp data.spacy spancat threshold spans_sc_f +> ``` + + +| Name | Description | +| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `model` | Pipeline to evaluate. Can be a package or a path to a data directory. ~~str (positional)~~ | +| `data_path` | Path to file with DocBin with docs to use for threshold search. ~~Path (positional)~~ | +| `pipe_name` | Name of pipe to examine thresholds for. ~~str (positional)~~ | +| `threshold_key` | Key of threshold attribute in component's configuration. ~~str (positional)~~ | +| `scores_key` | Name of score to metric to optimize. ~~str (positional)~~ | +| `--n_trials`, `-n` | Number of trials to determine optimal thresholds. ~~int (option)~~ | +| `--code`, `-c` | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ | +| `--gpu-id`, `-g` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ | +| `--gold-preproc`, `-G` | Use gold preprocessing. ~~bool (flag)~~ | +| `--silent`, `-V`, `-VV` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ | +| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ | + ## assemble {#assemble tag="command"} Assemble a pipeline from a config file without additional training. Expects a