find-threshold: CLI command for multi-label classifier threshold tuning (#11280)

* Add foundation for find-threshold CLI functionality.

* Finish first draft for find-threshold.

* Add tests.

* Revert adjusted import statements.

* Fix mypy errors.

* Fix imports.

* Harmonize arguments with spacy evaluate command.

* Generalize component and threshold handling. Harmonize arguments with 'spacy evaluate' CLI.

* Fix Spancat test.

* Add beta parameter to Scorer and PRFScore.

* Make beta a component scorer setting.

* Remove beta.

* Update nlp.config (workaround).

* Reload pipeline on threshold change. Adjust tests. Remove confection reference.

* Remove assumption of component being a Pipe object or having a .cfg attribute.

* Adjust test output and reference values.

* Remove beta references. Delete universe.json.

* Reverting unnecessary changes. Removing unused default values. Renaming variables in find-cli tests.

* Update spacy/cli/find_threshold.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Remove adding labels in tests.

* Remove unused error

* Undo changes to PRFScorer

* Change default value for n_trials. Log table iteratively.

* Add warnings for pointless applications of find_threshold().

* Fix imports.

* Adjust type check of TextCategorizer to exclude subclasses.

* Change check of if there's only one unique value in scores.

* Update spacy/cli/find_threshold.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Incorporate feedback.

* Fix test issue. Update docstring.

* Update docs & docstring.

* Update spacy/tests/test_cli.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Add examples to docs. Rename _nlp to nlp in tests.

* Update spacy/cli/find_threshold.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/cli/find_threshold.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Raphael Mitsch 2022-11-25 11:44:55 +01:00 committed by GitHub
parent dece775279
commit c0fd8a2e71
6 changed files with 399 additions and 5 deletions

spacy/cli/__init__.py

@@ -27,6 +27,7 @@ from .project.dvc import project_update_dvc  # noqa: F401
 from .project.push import project_push  # noqa: F401
 from .project.pull import project_pull  # noqa: F401
 from .project.document import project_document  # noqa: F401
+from .find_threshold import find_threshold  # noqa: F401

 @app.command("link", no_args_is_help=True, deprecated=True, hidden=True)

spacy/cli/find_threshold.py (new file, 233 lines)

@@ -0,0 +1,233 @@
import functools
import operator
from pathlib import Path
import logging
from typing import Optional, Tuple, Any, Dict, List

import numpy
import wasabi.tables

from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
from ..errors import Errors
from ..training import Corpus
from ._util import app, Arg, Opt, import_code, setup_gpu
from .. import util

_DEFAULTS = {
    "n_trials": 11,
    "use_gpu": -1,
    "gold_preproc": False,
}


@app.command(
    "find-threshold",
    context_settings={"allow_extra_args": False, "ignore_unknown_options": True},
)
def find_threshold_cli(
    # fmt: off
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
    pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
    threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
    scores_key: str = Arg(..., help="Metric to optimize"),
    n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    # fmt: on
):
    """
    Runs prediction trials for a trained model with varying thresholds to maximize
    the specified metric. The search space for the threshold is traversed linearly
    from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout`
    (the corresponding API call to `spacy.cli.find_threshold.find_threshold()`
    returns all results).

    This is applicable only for components whose predictions are influenced by
    thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note
    that the full path to the corresponding threshold attribute in the config has
    to be provided.

    DOCS: https://spacy.io/api/cli#find-threshold
    """
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
    import_code(code_path)
    find_threshold(
        model=model,
        data_path=data_path,
        pipe_name=pipe_name,
        threshold_key=threshold_key,
        scores_key=scores_key,
        n_trials=n_trials,
        use_gpu=use_gpu,
        gold_preproc=gold_preproc,
        silent=False,
    )

def find_threshold(
    model: str,
    data_path: Path,
    pipe_name: str,
    threshold_key: str,
    scores_key: str,
    *,
    n_trials: int = _DEFAULTS["n_trials"],  # type: ignore
    use_gpu: int = _DEFAULTS["use_gpu"],  # type: ignore
    gold_preproc: bool = _DEFAULTS["gold_preproc"],  # type: ignore
    silent: bool = True,
) -> Tuple[float, float, Dict[float, float]]:
    """
    Runs prediction trials for models with varying thresholds to maximize the specified metric.
    model (Union[str, Path]): Pipeline to evaluate. Can be a package or a path to a data directory.
    data_path (Path): Path to file with DocBin with docs to use for threshold search.
    pipe_name (str): Name of pipe to examine thresholds for.
    threshold_key (str): Key of threshold attribute in component's configuration.
    scores_key (str): Name of the score metric to optimize.
    n_trials (int): Number of trials to determine optimal thresholds.
    use_gpu (int): GPU ID or -1 for CPU.
    gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
        tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy
        due to train/test skew.
    silent (bool): Whether to print non-error-related output to stdout.
    RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
        evaluated thresholds.
    """
    setup_gpu(use_gpu, silent=silent)
    data_path = util.ensure_path(data_path)
    if not data_path.exists():
        wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
    nlp = util.load_model(model)

    if pipe_name not in nlp.component_names:
        raise AttributeError(
            Errors.E001.format(name=pipe_name, opts=nlp.component_names)
        )
    pipe = nlp.get_pipe(pipe_name)
    if not hasattr(pipe, "scorer"):
        raise AttributeError(Errors.E1047)
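    # The exact type check below is intentional: subclasses such as
    # MultiLabel_TextCategorizer do use thresholds and should not trigger this warning.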
    if type(pipe) == TextCategorizer:
        wasabi.msg.warn(
            "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
            "exclusive classes. All thresholds will yield the same results."
        )

    if not silent:
        wasabi.msg.info(
            title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
            f"trials."
        )

    # Load evaluation corpus.
    corpus = Corpus(data_path, gold_preproc=gold_preproc)
    dev_dataset = list(corpus(nlp))
    config_keys = threshold_key.split(".")
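    # E.g. "threshold" -> ["threshold"]; a nested key like "a.b" -> ["a", "b"].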

    def set_nested_item(
        config: Dict[str, Any], keys: List[str], value: float
    ) -> Dict[str, Any]:
        """Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200.
        config (Dict[str, Any]): Configuration dictionary.
        keys (List[Any]): Path to value to set.
        value (float): Value to set.
        RETURNS (Dict[str, Any]): Updated dictionary.
        """
        functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
        return config
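    # For illustration: set_nested_item({"a": {"b": 0.5}}, ["a", "b"], 0.9) mutates
    # the dict in place and returns {"a": {"b": 0.9}}.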

    def filter_config(
        config: Dict[str, Any], keys: List[str], full_key: str
    ) -> Dict[str, Any]:
        """Filters provided config dictionary so that only the specified keys path remains.
        config (Dict[str, Any]): Configuration dictionary.
        keys (List[Any]): Path to value to set.
        full_key (str): Full user-specified key.
        RETURNS (Dict[str, Any]): Filtered dictionary.
        """
        if keys[0] not in config:
            wasabi.msg.fail(
                title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
                text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
                f"{list(config.keys())}",
                exits=1,
            )
        return {
            keys[0]: filter_config(config[keys[0]], keys[1:], full_key)
            if len(keys) > 1
            else config[keys[0]]
        }
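    # For illustration: filter_config({"a": {"b": 1, "c": 2}}, ["a", "b"], "a.b")
    # returns {"a": {"b": 1}}, dropping the sibling key "c".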

    # Evaluate with varying threshold values.
    scores: Dict[float, float] = {}
    config_keys_full = ["components", pipe_name, *config_keys]
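    # E.g. ["components", "textcat_multilabel", "threshold"]: the full path to the
    # threshold attribute within nlp.config.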
    table_col_widths = (10, 10)
    thresholds = numpy.linspace(0, 1, n_trials)
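    # With the default n_trials=11 this yields the grid 0.0, 0.1, ..., 1.0.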
    print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))

    for threshold in thresholds:
        # Reload pipeline with overrides specifying the new threshold.
        nlp = util.load_model(
            model,
            config=set_nested_item(
                filter_config(
                    nlp.config, config_keys_full, ".".join(config_keys_full)
                ).copy(),
                config_keys_full,
                threshold,
            ),
        )
        if hasattr(pipe, "cfg"):
            setattr(
                nlp.get_pipe(pipe_name),
                "cfg",
                set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
            )

        eval_scores = nlp.evaluate(dev_dataset)
        if scores_key not in eval_scores:
            wasabi.msg.fail(
                title=f"Failed to look up score `{scores_key}` in evaluation results.",
                text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
                f"available: {list(eval_scores.keys())}",
                exits=1,
            )
        scores[threshold] = eval_scores[scores_key]

        if not isinstance(scores[threshold], (float, int)):
            wasabi.msg.fail(
                f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for "
                f"numeric scores.",
                exits=1,
            )
        print(
            wasabi.row(
                [round(threshold, 3), round(scores[threshold], 3)],
                widths=table_col_widths,
            )
        )

    best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
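    # Ties resolve to the lowest threshold: max() returns the first maximum it
    # encounters, and the thresholds were inserted in ascending order.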

    # If all scores are identical, emit warning.
    if len(set(scores.values())) == 1:
        wasabi.msg.warn(
            title="All scores are identical. Verify that all settings are correct.",
            text=""
            if (
                not isinstance(pipe, MultiLabel_TextCategorizer)
                or scores_key in ("cats_macro_f", "cats_micro_f")
            )
            else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
        )
    else:
        if not silent:
            print(
                f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of "
                f"{scores[best_threshold]}."
            )

    return best_threshold, scores[best_threshold], scores

spacy/errors.py

@@ -956,6 +956,7 @@ class Errors(metaclass=ErrorsWithCodes):
              "sure it's overwritten on the subclass.")
     E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
              "knowledge base, use `InMemoryLookupKB`.")
+    E1047 = ("`find_threshold()` only supports components with a `scorer` attribute.")

     # Deprecated model shortcuts, only used in errors and warnings

spacy/pipeline/spancat.py

@@ -1,7 +1,7 @@
-from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
+from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
 from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
 from thinc.api import Optimizer
-from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
+from thinc.types import Ragged, Ints2d, Floats2d
 import numpy

spacy/tests/test_cli.py

@@ -1,9 +1,10 @@
import os
import math
from collections import Counter
from typing import Tuple, List, Dict, Any
import pkg_resources
from random import sample
from typing import Counter
import numpy
import pytest
import srsly
from click import NoSuchOption
@@ -28,11 +29,12 @@ from spacy.cli.package import get_third_party_dependencies
 from spacy.cli.package import _is_permitted_package_name
 from spacy.cli.project.run import _check_requirements
 from spacy.cli.validate import get_model_pkgs
+from spacy.cli.find_threshold import find_threshold
 from spacy.lang.en import English
 from spacy.lang.nl import Dutch
 from spacy.language import Language
 from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
-from spacy.tokens import Doc
+from spacy.tokens import Doc, DocBin
 from spacy.tokens.span import Span
 from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
 from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
@@ -859,6 +861,122 @@ def test_span_length_freq_dist_output_must_be_correct():
    assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]


def test_cli_find_threshold(capsys):
    thresholds = numpy.linspace(0, 1, 10)
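    # Reference threshold grid; used below only as bounds when asserting where the best threshold lands.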
    def make_examples(nlp: Language) -> List[Example]:
        docs: List[Example] = []

        for t in [
            (
                "I am angry and confused in the Bank of America.",
                {
                    "cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0},
                    "spans": {"sc": [(31, 46, "ORG")]},
                },
            ),
            (
                "I am confused but happy in New York.",
                {
                    "cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0},
                    "spans": {"sc": [(27, 35, "GPE")]},
                },
            ),
        ]:
            doc = nlp.make_doc(t[0])
            docs.append(Example.from_dict(doc, t[1]))

        return docs

    def init_nlp(
        components: Tuple[Tuple[str, Dict[str, Any]], ...] = ()
    ) -> Tuple[Language, List[Example]]:
        new_nlp = English()
        new_nlp.add_pipe(  # type: ignore
            factory_name="textcat_multilabel",
            name="tc_multi",
            config={"threshold": 0.9},
        )

        # Append additional components to pipeline.
        for cfn, comp_config in components:
            new_nlp.add_pipe(cfn, config=comp_config)

        new_examples = make_examples(new_nlp)
        new_nlp.initialize(get_examples=lambda: new_examples)
        for i in range(5):
            new_nlp.update(new_examples)

        return new_nlp, new_examples

    with make_tempdir() as docs_dir:
        # Check whether find_threshold() identifies lowest threshold above 0 as (first) ideal threshold, as this
        # matches the current model behavior with the examples above. This can break once the model behavior
        # changes and serves mostly as a smoke test.
        nlp, examples = init_nlp()
        DocBin(docs=[example.reference for example in examples]).to_disk(
            docs_dir / "docs.spacy"
        )

        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            res = find_threshold(
                model=nlp_dir,
                data_path=docs_dir / "docs.spacy",
                pipe_name="tc_multi",
                threshold_key="threshold",
                scores_key="cats_macro_f",
                silent=True,
            )
            assert res[0] != thresholds[0]
            assert thresholds[0] < res[0] < thresholds[9]
            assert res[1] == 1.0
            assert res[2][1.0] == 0.0

        # Test with spancat.
        nlp, _ = init_nlp((("spancat", {}),))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            res = find_threshold(
                model=nlp_dir,
                data_path=docs_dir / "docs.spacy",
                pipe_name="spancat",
                threshold_key="threshold",
                scores_key="spans_sc_f",
                silent=True,
            )
            assert res[0] != thresholds[0]
            assert thresholds[0] < res[0] < thresholds[8]
            assert res[1] >= 0.6
            assert res[2][1.0] == 0.0

        # Having multiple textcat_multilabel components should work, since the name has to be specified.
        nlp, _ = init_nlp((("textcat_multilabel", {}),))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            assert find_threshold(
                model=nlp_dir,
                data_path=docs_dir / "docs.spacy",
                pipe_name="tc_multi",
                threshold_key="threshold",
                scores_key="cats_macro_f",
                silent=True,
            )

        # Specifying the name of a non-existing pipe should fail.
        nlp, _ = init_nlp()
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(AttributeError):
                find_threshold(
                    model=nlp_dir,
                    data_path=docs_dir / "docs.spacy",
                    pipe_name="_",
                    threshold_key="threshold",
                    scores_key="cats_macro_f",
                    silent=True,
                )

@pytest.mark.parametrize(
    "reqs,output",
    [

website/docs/api/cli.md

@@ -12,6 +12,7 @@ menu:
   - ['train', 'train']
   - ['pretrain', 'pretrain']
   - ['evaluate', 'evaluate']
+  - ['find-threshold', 'find-threshold']
   - ['assemble', 'assemble']
   - ['package', 'package']
   - ['project', 'project']
@@ -1161,6 +1162,46 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-prepr
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
| **CREATES**    | Training results and optional metrics and visualizations.  |

## find-threshold {#find-threshold new="3.5" tag="command"}

Runs prediction trials for a trained model with varying thresholds to maximize
the specified metric. The search space for the threshold is traversed linearly
from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout`
(the corresponding API call to `spacy.cli.find_threshold.find_threshold()`
returns all results).

This is applicable only for components whose predictions are influenced by
thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note
that the full path to the corresponding threshold attribute in the config has
to be provided.

> #### Examples
>
> ```cli
> # For textcat_multilabel:
> $ python -m spacy find-threshold my_nlp data.spacy textcat_multilabel threshold cats_macro_f
> ```
>
> ```cli
> # For spancat:
> $ python -m spacy find-threshold my_nlp data.spacy spancat threshold spans_sc_f
> ```

| Name                     | Description                                                                                                                                                                           |
| ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `model`                  | Pipeline to evaluate. Can be a package or a path to a data directory. ~~str (positional)~~                                                                                            |
| `data_path`              | Path to file with DocBin with docs to use for threshold search. ~~Path (positional)~~                                                                                                 |
| `pipe_name`              | Name of pipe to examine thresholds for. ~~str (positional)~~                                                                                                                          |
| `threshold_key`          | Key of threshold attribute in component's configuration. ~~str (positional)~~                                                                                                         |
| `scores_key`             | Name of the score metric to optimize. ~~str (positional)~~                                                                                                                            |
| `--n_trials`, `-n`       | Number of trials to determine optimal thresholds. ~~int (option)~~                                                                                                                    |
| `--code`, `-c`           | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ |
| `--gpu-id`, `-g`         | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~                                                                                                                        |
| `--gold-preproc`, `-G`   | Use gold preprocessing. ~~bool (flag)~~                                                                                                                                               |
| `--verbose`, `-V`, `-VV` | Display more information for debugging purposes. ~~bool (flag)~~                                                                                                                      |
| `--help`, `-h`           | Show help message and available arguments. ~~bool (flag)~~                                                                                                                            |
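
The search can also be run programmatically via
`spacy.cli.find_threshold.find_threshold()`, which returns the best threshold,
its score, and the scores for all evaluated thresholds. A minimal sketch, with
placeholder paths and a `textcat_multilabel` setup assumed:

```python
from spacy.cli.find_threshold import find_threshold

# Placeholder paths: point these at your own pipeline and evaluation DocBin.
best_threshold, best_score, scores = find_threshold(
    model="training/model-best",
    data_path="corpus/dev.spacy",
    pipe_name="textcat_multilabel",
    threshold_key="threshold",
    scores_key="cats_macro_f",
    silent=True,  # suppress the table on stdout; results are still returned
)
print(f"Best threshold: {best_threshold} with cats_macro_f of {best_score}.")
```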
## assemble {#assemble tag="command"}
Assemble a pipeline from a config file without additional training. Expects a