mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
find-threshold: CLI command for multi-label classifier threshold tuning (#11280)
* Add foundation for find-threshold CLI functionality. * Finish first draft for find-threshold. * Add tests. * Revert adjusted import statements. * Fix mypy errors. * Fix imports. * Harmonize arguments with spacy evaluate command. * Generalize component and threshold handling. Harmonize arguments with 'spacy evaluate' CLI. * Fix Spancat test. * Add beta parameter to Scorer and PRFScore. * Make beta a component scorer setting. * Remove beta. * Update nlp.config (workaround). * Reload pipeline on threshold change. Adjust tests. Remove confection reference. * Remove assumption of component being a Pipe object or having a .cfg attribute. * Adjust test output and reference values. * Remove beta references. Delete universe.json. * Reverting unnecessary changes. Removing unused default values. Renaming variables in find-cli tests. * Update spacy/cli/find_threshold.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Remove adding labels in tests. * Remove unused error * Undo changes to PRFScorer * Change default value for n_trials. Log table iteratively. * Add warnings for pointless applications of find_threshold(). * Fix imports. * Adjust type check of TextCategorizer to exclude subclasses. * Change check of if there's only one unique value in scores. * Update spacy/cli/find_threshold.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Incorporate feedback. * Fix test issue. Update docstring. * Update docs & docstring. * Update spacy/tests/test_cli.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Add examples to docs. Rename _nlp to nlp in tests. * Update spacy/cli/find_threshold.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/cli/find_threshold.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
dece775279
commit
c0fd8a2e71
|
@ -27,6 +27,7 @@ from .project.dvc import project_update_dvc # noqa: F401
|
||||||
from .project.push import project_push # noqa: F401
|
from .project.push import project_push # noqa: F401
|
||||||
from .project.pull import project_pull # noqa: F401
|
from .project.pull import project_pull # noqa: F401
|
||||||
from .project.document import project_document # noqa: F401
|
from .project.document import project_document # noqa: F401
|
||||||
|
from .find_threshold import find_threshold # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
|
@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
|
||||||
|
|
233
spacy/cli/find_threshold.py
Normal file
233
spacy/cli/find_threshold.py
Normal file
|
@ -0,0 +1,233 @@
|
||||||
|
import functools
|
||||||
|
import operator
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple, Any, Dict, List
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
import wasabi.tables
|
||||||
|
|
||||||
|
from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
|
||||||
|
from ..errors import Errors
|
||||||
|
from ..training import Corpus
|
||||||
|
from ._util import app, Arg, Opt, import_code, setup_gpu
|
||||||
|
from .. import util
|
||||||
|
|
||||||
|
_DEFAULTS = {
|
||||||
|
"n_trials": 11,
|
||||||
|
"use_gpu": -1,
|
||||||
|
"gold_preproc": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(
|
||||||
|
"find-threshold",
|
||||||
|
context_settings={"allow_extra_args": False, "ignore_unknown_options": True},
|
||||||
|
)
|
||||||
|
def find_threshold_cli(
|
||||||
|
# fmt: off
|
||||||
|
model: str = Arg(..., help="Model name or path"),
|
||||||
|
data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
|
||||||
|
pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
|
||||||
|
threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
|
||||||
|
scores_key: str = Arg(..., help="Metric to optimize"),
|
||||||
|
n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
|
||||||
|
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||||
|
use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
||||||
|
gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
|
||||||
|
verbose: bool = Opt(False, "--silent", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||||
|
# fmt: on
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Runs prediction trials for a trained model with varying tresholds to maximize
|
||||||
|
the specified metric. The search space for the threshold is traversed linearly
|
||||||
|
from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout`
|
||||||
|
(the corresponding API call to `spacy.cli.find_threshold.find_threshold()`
|
||||||
|
returns all results).
|
||||||
|
|
||||||
|
This is applicable only for components whose predictions are influenced by
|
||||||
|
thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note
|
||||||
|
that the full path to the corresponding threshold attribute in the config has to
|
||||||
|
be provided.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/cli#find-threshold
|
||||||
|
"""
|
||||||
|
|
||||||
|
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||||
|
import_code(code_path)
|
||||||
|
find_threshold(
|
||||||
|
model=model,
|
||||||
|
data_path=data_path,
|
||||||
|
pipe_name=pipe_name,
|
||||||
|
threshold_key=threshold_key,
|
||||||
|
scores_key=scores_key,
|
||||||
|
n_trials=n_trials,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
gold_preproc=gold_preproc,
|
||||||
|
silent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def find_threshold(
|
||||||
|
model: str,
|
||||||
|
data_path: Path,
|
||||||
|
pipe_name: str,
|
||||||
|
threshold_key: str,
|
||||||
|
scores_key: str,
|
||||||
|
*,
|
||||||
|
n_trials: int = _DEFAULTS["n_trials"], # type: ignore
|
||||||
|
use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore
|
||||||
|
gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore
|
||||||
|
silent: bool = True,
|
||||||
|
) -> Tuple[float, float, Dict[float, float]]:
|
||||||
|
"""
|
||||||
|
Runs prediction trials for models with varying tresholds to maximize the specified metric.
|
||||||
|
model (Union[str, Path]): Pipeline to evaluate. Can be a package or a path to a data directory.
|
||||||
|
data_path (Path): Path to file with DocBin with docs to use for threshold search.
|
||||||
|
pipe_name (str): Name of pipe to examine thresholds for.
|
||||||
|
threshold_key (str): Key of threshold attribute in component's configuration.
|
||||||
|
scores_key (str): Name of score to metric to optimize.
|
||||||
|
n_trials (int): Number of trials to determine optimal thresholds.
|
||||||
|
use_gpu (int): GPU ID or -1 for CPU.
|
||||||
|
gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
|
||||||
|
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
|
||||||
|
to train/test skew.
|
||||||
|
silent (bool): Whether to print non-error-related output to stdout.
|
||||||
|
RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
|
||||||
|
evaluated thresholds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
setup_gpu(use_gpu, silent=silent)
|
||||||
|
data_path = util.ensure_path(data_path)
|
||||||
|
if not data_path.exists():
|
||||||
|
wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
|
||||||
|
nlp = util.load_model(model)
|
||||||
|
|
||||||
|
if pipe_name not in nlp.component_names:
|
||||||
|
raise AttributeError(
|
||||||
|
Errors.E001.format(name=pipe_name, opts=nlp.component_names)
|
||||||
|
)
|
||||||
|
pipe = nlp.get_pipe(pipe_name)
|
||||||
|
if not hasattr(pipe, "scorer"):
|
||||||
|
raise AttributeError(Errors.E1045)
|
||||||
|
|
||||||
|
if type(pipe) == TextCategorizer:
|
||||||
|
wasabi.msg.warn(
|
||||||
|
"The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
|
||||||
|
"exclusive classes. All thresholds will yield the same results."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not silent:
|
||||||
|
wasabi.msg.info(
|
||||||
|
title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
|
||||||
|
f"trials."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load evaluation corpus.
|
||||||
|
corpus = Corpus(data_path, gold_preproc=gold_preproc)
|
||||||
|
dev_dataset = list(corpus(nlp))
|
||||||
|
config_keys = threshold_key.split(".")
|
||||||
|
|
||||||
|
def set_nested_item(
|
||||||
|
config: Dict[str, Any], keys: List[str], value: float
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Set item in nested dictionary. Adapted from https://stackoverflow.com/a/54138200.
|
||||||
|
config (Dict[str, Any]): Configuration dictionary.
|
||||||
|
keys (List[Any]): Path to value to set.
|
||||||
|
value (float): Value to set.
|
||||||
|
RETURNS (Dict[str, Any]): Updated dictionary.
|
||||||
|
"""
|
||||||
|
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
|
||||||
|
return config
|
||||||
|
|
||||||
|
def filter_config(
|
||||||
|
config: Dict[str, Any], keys: List[str], full_key: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Filters provided config dictionary so that only the specified keys path remains.
|
||||||
|
config (Dict[str, Any]): Configuration dictionary.
|
||||||
|
keys (List[Any]): Path to value to set.
|
||||||
|
full_key (str): Full user-specified key.
|
||||||
|
RETURNS (Dict[str, Any]): Filtered dictionary.
|
||||||
|
"""
|
||||||
|
if keys[0] not in config:
|
||||||
|
wasabi.msg.fail(
|
||||||
|
title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
|
||||||
|
text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
|
||||||
|
f"{list(config.keys())}",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
keys[0]: filter_config(config[keys[0]], keys[1:], full_key)
|
||||||
|
if len(keys) > 1
|
||||||
|
else config[keys[0]]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Evaluate with varying threshold values.
|
||||||
|
scores: Dict[float, float] = {}
|
||||||
|
config_keys_full = ["components", pipe_name, *config_keys]
|
||||||
|
table_col_widths = (10, 10)
|
||||||
|
thresholds = numpy.linspace(0, 1, n_trials)
|
||||||
|
print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
|
||||||
|
for threshold in thresholds:
|
||||||
|
# Reload pipeline with overrides specifying the new threshold.
|
||||||
|
nlp = util.load_model(
|
||||||
|
model,
|
||||||
|
config=set_nested_item(
|
||||||
|
filter_config(
|
||||||
|
nlp.config, config_keys_full, ".".join(config_keys_full)
|
||||||
|
).copy(),
|
||||||
|
config_keys_full,
|
||||||
|
threshold,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if hasattr(pipe, "cfg"):
|
||||||
|
setattr(
|
||||||
|
nlp.get_pipe(pipe_name),
|
||||||
|
"cfg",
|
||||||
|
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_scores = nlp.evaluate(dev_dataset)
|
||||||
|
if scores_key not in eval_scores:
|
||||||
|
wasabi.msg.fail(
|
||||||
|
title=f"Failed to look up score `{scores_key}` in evaluation results.",
|
||||||
|
text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
|
||||||
|
f"available: {list(eval_scores.keys())}",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
|
scores[threshold] = eval_scores[scores_key]
|
||||||
|
|
||||||
|
if not isinstance(scores[threshold], (float, int)):
|
||||||
|
wasabi.msg.fail(
|
||||||
|
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
|
||||||
|
f"scores.",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
wasabi.row(
|
||||||
|
[round(threshold, 3), round(scores[threshold], 3)],
|
||||||
|
widths=table_col_widths,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
|
||||||
|
|
||||||
|
# If all scores are identical, emit warning.
|
||||||
|
if len(set(scores.values())) == 1:
|
||||||
|
wasabi.msg.warn(
|
||||||
|
title="All scores are identical. Verify that all settings are correct.",
|
||||||
|
text=""
|
||||||
|
if (
|
||||||
|
not isinstance(pipe, MultiLabel_TextCategorizer)
|
||||||
|
or scores_key in ("cats_macro_f", "cats_micro_f")
|
||||||
|
)
|
||||||
|
else "Use `cats_macro_f` or `cats_micro_f` when optimizing the threshold for `textcat_multilabel`.",
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if not silent:
|
||||||
|
print(
|
||||||
|
f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return best_threshold, scores[best_threshold], scores
|
|
@ -956,6 +956,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
"sure it's overwritten on the subclass.")
|
"sure it's overwritten on the subclass.")
|
||||||
E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
|
E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
|
||||||
"knowledge base, use `InMemoryLookupKB`.")
|
"knowledge base, use `InMemoryLookupKB`.")
|
||||||
|
E1047 = ("`find_threshold()` only supports components with a `scorer` attribute.")
|
||||||
|
|
||||||
|
|
||||||
# Deprecated model shortcuts, only used in errors and warnings
|
# Deprecated model shortcuts, only used in errors and warnings
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
|
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
|
||||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||||
from thinc.api import Optimizer
|
from thinc.api import Optimizer
|
||||||
from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
|
from thinc.types import Ragged, Ints2d, Floats2d
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
from typing import Tuple, List, Dict, Any
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from random import sample
|
|
||||||
from typing import Counter
|
|
||||||
|
|
||||||
|
import numpy
|
||||||
import pytest
|
import pytest
|
||||||
import srsly
|
import srsly
|
||||||
from click import NoSuchOption
|
from click import NoSuchOption
|
||||||
|
@ -28,11 +29,12 @@ from spacy.cli.package import get_third_party_dependencies
|
||||||
from spacy.cli.package import _is_permitted_package_name
|
from spacy.cli.package import _is_permitted_package_name
|
||||||
from spacy.cli.project.run import _check_requirements
|
from spacy.cli.project.run import _check_requirements
|
||||||
from spacy.cli.validate import get_model_pkgs
|
from spacy.cli.validate import get_model_pkgs
|
||||||
|
from spacy.cli.find_threshold import find_threshold
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.nl import Dutch
|
from spacy.lang.nl import Dutch
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc, DocBin
|
||||||
from spacy.tokens.span import Span
|
from spacy.tokens.span import Span
|
||||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||||
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
||||||
|
@ -859,6 +861,122 @@ def test_span_length_freq_dist_output_must_be_correct():
|
||||||
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_find_threshold(capsys):
|
||||||
|
thresholds = numpy.linspace(0, 1, 10)
|
||||||
|
|
||||||
|
def make_examples(nlp: Language) -> List[Example]:
|
||||||
|
docs: List[Example] = []
|
||||||
|
|
||||||
|
for t in [
|
||||||
|
(
|
||||||
|
"I am angry and confused in the Bank of America.",
|
||||||
|
{
|
||||||
|
"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0},
|
||||||
|
"spans": {"sc": [(31, 46, "ORG")]},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"I am confused but happy in New York.",
|
||||||
|
{
|
||||||
|
"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0},
|
||||||
|
"spans": {"sc": [(27, 35, "GPE")]},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]:
|
||||||
|
doc = nlp.make_doc(t[0])
|
||||||
|
docs.append(Example.from_dict(doc, t[1]))
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def init_nlp(
|
||||||
|
components: Tuple[Tuple[str, Dict[str, Any]], ...] = ()
|
||||||
|
) -> Tuple[Language, List[Example]]:
|
||||||
|
new_nlp = English()
|
||||||
|
new_nlp.add_pipe( # type: ignore
|
||||||
|
factory_name="textcat_multilabel",
|
||||||
|
name="tc_multi",
|
||||||
|
config={"threshold": 0.9},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append additional components to pipeline.
|
||||||
|
for cfn, comp_config in components:
|
||||||
|
new_nlp.add_pipe(cfn, config=comp_config)
|
||||||
|
|
||||||
|
new_examples = make_examples(new_nlp)
|
||||||
|
new_nlp.initialize(get_examples=lambda: new_examples)
|
||||||
|
for i in range(5):
|
||||||
|
new_nlp.update(new_examples)
|
||||||
|
|
||||||
|
return new_nlp, new_examples
|
||||||
|
|
||||||
|
with make_tempdir() as docs_dir:
|
||||||
|
# Check whether find_threshold() identifies lowest threshold above 0 as (first) ideal threshold, as this matches
|
||||||
|
# the current model behavior with the examples above. This can break once the model behavior changes and serves
|
||||||
|
# mostly as a smoke test.
|
||||||
|
nlp, examples = init_nlp()
|
||||||
|
DocBin(docs=[example.reference for example in examples]).to_disk(
|
||||||
|
docs_dir / "docs.spacy"
|
||||||
|
)
|
||||||
|
with make_tempdir() as nlp_dir:
|
||||||
|
nlp.to_disk(nlp_dir)
|
||||||
|
res = find_threshold(
|
||||||
|
model=nlp_dir,
|
||||||
|
data_path=docs_dir / "docs.spacy",
|
||||||
|
pipe_name="tc_multi",
|
||||||
|
threshold_key="threshold",
|
||||||
|
scores_key="cats_macro_f",
|
||||||
|
silent=True,
|
||||||
|
)
|
||||||
|
assert res[0] != thresholds[0]
|
||||||
|
assert thresholds[0] < res[0] < thresholds[9]
|
||||||
|
assert res[1] == 1.0
|
||||||
|
assert res[2][1.0] == 0.0
|
||||||
|
|
||||||
|
# Test with spancat.
|
||||||
|
nlp, _ = init_nlp((("spancat", {}),))
|
||||||
|
with make_tempdir() as nlp_dir:
|
||||||
|
nlp.to_disk(nlp_dir)
|
||||||
|
res = find_threshold(
|
||||||
|
model=nlp_dir,
|
||||||
|
data_path=docs_dir / "docs.spacy",
|
||||||
|
pipe_name="spancat",
|
||||||
|
threshold_key="threshold",
|
||||||
|
scores_key="spans_sc_f",
|
||||||
|
silent=True,
|
||||||
|
)
|
||||||
|
assert res[0] != thresholds[0]
|
||||||
|
assert thresholds[0] < res[0] < thresholds[8]
|
||||||
|
assert res[1] >= 0.6
|
||||||
|
assert res[2][1.0] == 0.0
|
||||||
|
|
||||||
|
# Having multiple textcat_multilabel components should work, since the name has to be specified.
|
||||||
|
nlp, _ = init_nlp((("textcat_multilabel", {}),))
|
||||||
|
with make_tempdir() as nlp_dir:
|
||||||
|
nlp.to_disk(nlp_dir)
|
||||||
|
assert find_threshold(
|
||||||
|
model=nlp_dir,
|
||||||
|
data_path=docs_dir / "docs.spacy",
|
||||||
|
pipe_name="tc_multi",
|
||||||
|
threshold_key="threshold",
|
||||||
|
scores_key="cats_macro_f",
|
||||||
|
silent=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Specifying the name of an non-existing pipe should fail.
|
||||||
|
nlp, _ = init_nlp()
|
||||||
|
with make_tempdir() as nlp_dir:
|
||||||
|
nlp.to_disk(nlp_dir)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
find_threshold(
|
||||||
|
model=nlp_dir,
|
||||||
|
data_path=docs_dir / "docs.spacy",
|
||||||
|
pipe_name="_",
|
||||||
|
threshold_key="threshold",
|
||||||
|
scores_key="cats_macro_f",
|
||||||
|
silent=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"reqs,output",
|
"reqs,output",
|
||||||
[
|
[
|
||||||
|
|
|
@ -12,6 +12,7 @@ menu:
|
||||||
- ['train', 'train']
|
- ['train', 'train']
|
||||||
- ['pretrain', 'pretrain']
|
- ['pretrain', 'pretrain']
|
||||||
- ['evaluate', 'evaluate']
|
- ['evaluate', 'evaluate']
|
||||||
|
- ['find-threshold', 'find-threshold']
|
||||||
- ['assemble', 'assemble']
|
- ['assemble', 'assemble']
|
||||||
- ['package', 'package']
|
- ['package', 'package']
|
||||||
- ['project', 'project']
|
- ['project', 'project']
|
||||||
|
@ -1161,6 +1162,46 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-prepr
|
||||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
| **CREATES** | Training results and optional metrics and visualizations. |
|
| **CREATES** | Training results and optional metrics and visualizations. |
|
||||||
|
|
||||||
|
## find-threshold {#find-threshold new="3.5" tag="command"}
|
||||||
|
|
||||||
|
Runs prediction trials for a trained model with varying tresholds to maximize
|
||||||
|
the specified metric. The search space for the threshold is traversed linearly
|
||||||
|
from 0 to 1 in `n_trials` steps. Results are displayed in a table on `stdout`
|
||||||
|
(the corresponding API call to `spacy.cli.find_threshold.find_threshold()`
|
||||||
|
returns all results).
|
||||||
|
|
||||||
|
This is applicable only for components whose predictions are influenced by
|
||||||
|
thresholds - e.g. `textcat_multilabel` and `spancat`, but not `textcat`. Note
|
||||||
|
that the full path to the corresponding threshold attribute in the config has to
|
||||||
|
be provided.
|
||||||
|
|
||||||
|
> #### Examples
|
||||||
|
>
|
||||||
|
> ```cli
|
||||||
|
> # For textcat_multilabel:
|
||||||
|
> $ python -m spacy find-threshold my_nlp data.spacy textcat_multilabel threshold cats_macro_f
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> ```cli
|
||||||
|
> # For spancat:
|
||||||
|
> $ python -m spacy find-threshold my_nlp data.spacy spancat threshold spans_sc_f
|
||||||
|
> ```
|
||||||
|
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `model` | Pipeline to evaluate. Can be a package or a path to a data directory. ~~str (positional)~~ |
|
||||||
|
| `data_path` | Path to file with DocBin with docs to use for threshold search. ~~Path (positional)~~ |
|
||||||
|
| `pipe_name` | Name of pipe to examine thresholds for. ~~str (positional)~~ |
|
||||||
|
| `threshold_key` | Key of threshold attribute in component's configuration. ~~str (positional)~~ |
|
||||||
|
| `scores_key` | Name of score to metric to optimize. ~~str (positional)~~ |
|
||||||
|
| `--n_trials`, `-n` | Number of trials to determine optimal thresholds. ~~int (option)~~ |
|
||||||
|
| `--code`, `-c` | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ |
|
||||||
|
| `--gpu-id`, `-g` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ |
|
||||||
|
| `--gold-preproc`, `-G` | Use gold preprocessing. ~~bool (flag)~~ |
|
||||||
|
| `--silent`, `-V`, `-VV` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ |
|
||||||
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
|
|
||||||
## assemble {#assemble tag="command"}
|
## assemble {#assemble tag="command"}
|
||||||
|
|
||||||
Assemble a pipeline from a config file without additional training. Expects a
|
Assemble a pipeline from a config file without additional training. Expects a
|
||||||
|
|
Loading…
Reference in New Issue
Block a user