Add tests.

This commit is contained in:
Raphael Mitsch 2022-08-08 16:39:22 +02:00
parent 4981700ced
commit 1d0f5d3592
2 changed files with 168 additions and 34 deletions

View File

@ -1,6 +1,6 @@
from pathlib import Path
import logging
from typing import Optional
from typing import Optional, Tuple, Union
import numpy
import wasabi.tables
@ -57,23 +57,26 @@ def find_threshold_cli(
def find_threshold(
model_path: Path,
doc_path: Path,
model_path: Union[str, Path],
doc_path: Union[str, Path],
*,
average: str = _DEFAULTS["average"],
pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
n_trials: int = _DEFAULTS["n_trials"],
beta: float = _DEFAULTS["beta"],
) -> None:
verbose: bool = True,
) -> Tuple[float, float]:
"""
Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
model_path (Path): Path to file with trained model.
doc_path (Path): Path to file with DocBin with docs to use for threshold search.
model_path (Union[str, Path]): Path to file with trained model.
doc_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
average (str): How to average F-scores across labels. One of ('micro', 'macro').
pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
is selected. If there are multiple, an error is raised.
n_trials (int): Number of trials to determine optimal thresholds
beta (float): Beta for F1 calculation. Ignored if different metric is used.
verbose (bool): Whether to print non-error-related output to stdout.
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
"""
nlp = util.load_model(model_path)
@ -90,10 +93,13 @@ def find_threshold(
if pipe_name and _pipe_name == pipe_name:
if not isinstance(_pipe, MultiLabel_TextCategorizer):
wasabi.msg.fail(
"Specified component {component} is not of type `MultiLabel_TextCategorizer`.",
"Specified component '{component}' is not of type `MultiLabel_TextCategorizer`.".format(
component=pipe_name
),
exits=1,
)
pipe = _pipe
print(pipe_name, _pipe_name, pipe.labels)
break
elif pipe_name is None:
if isinstance(_pipe, MultiLabel_TextCategorizer):
@ -116,10 +122,11 @@ def find_threshold(
exits=1,
)
print(
f"Searching threshold with the best {average} F-score for pipe '{selected_pipe_name}' with {n_trials} trials"
f" and beta = {beta}."
)
if verbose:
print(
f"Searching threshold with the best {average} F-score for component '{selected_pipe_name}' with {n_trials} "
f"trials and beta = {beta}."
)
thresholds = numpy.linspace(0, 1, n_trials)
ref_pos_counts = {label: 0 for label in pipe.labels}
@ -146,13 +153,15 @@ def find_threshold(
# Collect count stats per threshold value and label.
for threshold in thresholds:
for label, score in pred_doc.cats.items():
if label not in pipe.labels:
continue
label_value = int(score >= threshold)
if label_value == ref_doc.cats[label] == 1:
pred_pos_counts[threshold][True][label] += 1
elif label_value == 1 and ref_doc.cats[label] == 0:
pred_pos_counts[threshold][False][label] += 1
# Compute f_scores.
# Compute F-scores.
for threshold in thresholds:
for label in ref_pos_counts:
n_pos_preds = (
@ -188,20 +197,21 @@ def find_threshold(
) / len(ref_pos_counts)
best_threshold = max(f_scores, key=f_scores.get)
print(
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}."
)
print(
wasabi.tables.table(
data=[
(threshold, label, f_score)
for threshold, label_f_scores in f_scores_per_label.items()
for label, f_score in label_f_scores.items()
],
header=["Threshold", "Label", "F-Score"],
),
wasabi.tables.table(
data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
header=["Threshold", f"F-Score ({average})"],
),
)
if verbose:
print(
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
wasabi.tables.table(
data=[
(threshold, label, f_score)
for threshold, label_f_scores in f_scores_per_label.items()
for label, f_score in label_f_scores.items()
],
header=["Threshold", "Label", "F-Score"],
),
wasabi.tables.table(
data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
header=["Threshold", f"F-Score ({average})"],
),
)
return best_threshold, f_scores[best_threshold]

View File

@ -1,8 +1,8 @@
import os
import math
from random import sample
from typing import Counter
from typing import Counter, Iterable, Tuple, List
import numpy
import pytest
import srsly
from click import NoSuchOption
@ -26,19 +26,23 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.package import get_third_party_dependencies
from spacy.cli.package import _is_permitted_package_name
from spacy.cli.validate import get_model_pkgs
from spacy.cli.find_threshold import find_threshold
from spacy.lang.en import English
from spacy.lang.nl import Dutch
from spacy.language import Language
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.tokens import Doc
from spacy.tokens import Doc, DocBin
from spacy.tokens.span import Span
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
from spacy.training.converters import iob_to_docs
from spacy.pipeline import TextCategorizer, Pipe
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
from ..cli.init_pipeline import _init_labels
from .util import make_tempdir
# from ..cli.init_pipeline import _init_labels
# from .util import make_tempdir
from spacy.cli.init_pipeline import _init_labels
from spacy.tests.util import make_tempdir
@pytest.mark.issue(4665)
@ -855,3 +859,123 @@ def test_span_length_freq_dist_output_must_be_correct():
span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
assert sum(span_freqs.values()) >= threshold
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
def test_cli_find_threshold(capsys):
    """Smoke-test `find_threshold()` against a small `textcat_multilabel` pipeline.

    Covers two successful threshold searches and five configurations that are
    expected to abort with exit code 1.

    capsys: pytest capture fixture; requested by the signature but not read here.
    """

    def make_get_examples_multi_label(_nlp: Language) -> List[Example]:
        """Build two multi-label examples whose docs are tokenized with `_nlp`."""
        annotated_texts = [
            (
                "I'm angry and confused",
                {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
            ),
            (
                "I'm confused but happy",
                {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
            ),
        ]
        return [
            Example.from_dict(_nlp.make_doc(text), annotations)
            for text, annotations in annotated_texts
        ]

    def init_nlp(
        component_factory_names: Tuple[str, ...] = (),
    ) -> Tuple[Language, List[Example]]:
        """Create and briefly train an English pipeline with a `tc_multi` component.

        component_factory_names (Tuple[str, ...]): Factory names of additional
            components to add after `tc_multi`.
        RETURNS (Tuple[Language, List[Example]]): The updated pipeline and the
            examples it was updated with.
        """
        _nlp = English()
        textcat: TextCategorizer = _nlp.add_pipe(factory_name="textcat_multilabel", name="tc_multi")  # type: ignore
        for label in ("ANGRY", "CONFUSED", "HAPPY"):
            textcat.add_label(label)

        for cfn in component_factory_names:
            comp = _nlp.add_pipe(cfn)
            # Extra text classifiers get a placeholder label
            # (presumably required for initialization - TODO confirm).
            if isinstance(comp, TextCategorizer):
                comp.add_label("dummy")

        _nlp.initialize()
        _examples = make_get_examples_multi_label(_nlp)
        for _ in range(5):
            _nlp.update(_examples)
        return _nlp, _examples

    def assert_exits_with_error(_nlp: Language, docs_path, **kwargs) -> None:
        """Serialize `_nlp` and check that `find_threshold()` exits with code 1.

        _nlp (Language): Pipeline to write to a temporary directory.
        docs_path: Path to the serialized DocBin passed to `find_threshold()`.
        **kwargs: Forwarded to `find_threshold()` as keyword arguments.
        """
        with make_tempdir() as nlp_dir:
            _nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_path, **kwargs)
            # Checked after the `raises` block: statements following the
            # failing call inside that block would never be executed.
            assert error.value.code == 1

    with make_tempdir() as docs_dir:
        docs_path = docs_dir / "docs"
        # Second point of a 10-step grid over [0, 1], i.e. the lowest threshold above 0.
        # NOTE(review): assumes find_threshold()'s default n_trials produces the same grid - confirm.
        expected_threshold = numpy.linspace(0, 1, 10)[1]

        # Check whether find_threshold() identifies lowest threshold above 0 as (first) ideal threshold, as this matches
        # the current model behavior with the examples above. This can break once the model behavior changes and serves
        # mostly as a smoke test.
        nlp, examples = init_nlp()
        DocBin(docs=[example.reference for example in examples]).to_disk(docs_path)
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            best_threshold = find_threshold(nlp_dir, docs_path, verbose=False)[0]
            assert best_threshold == expected_threshold

        # Specifying name of non-MultiLabel_TextCategorizer component should fail.
        nlp, _ = init_nlp(("sentencizer",))
        assert_exits_with_error(nlp, docs_path, pipe_name="sentencizer")

        # Having multiple textcat_multilabel components without specifying the name should fail.
        nlp, _ = init_nlp(("textcat_multilabel",))
        assert_exits_with_error(nlp, docs_path)

        # Having multiple textcat_multilabel components should work when specifying the name.
        nlp, _ = init_nlp(("textcat_multilabel",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            best_threshold = find_threshold(
                nlp_dir, docs_path, pipe_name="tc_multi", verbose=False
            )[0]
            assert best_threshold == expected_threshold

        # Specifying the name of a non-existing pipe should fail.
        nlp, _ = init_nlp()
        assert_exits_with_error(nlp, docs_path, pipe_name="_")

        # Using a pipeline with no textcat components should fail.
        nlp = English()
        assert_exits_with_error(nlp, docs_path)

        # Specifying scores not in range 0 <= x <= 1 should fail.
        nlp, _ = init_nlp()
        out_of_range_texts = [
            (
                "I'm angry and confused",
                {"cats": {"ANGRY": 1.0, "CONFUSED": 2.0, "HAPPY": 0.0}},
            ),
            (
                "I'm confused but happy",
                {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
            ),
        ]
        # Overwrites the DocBin written above; no later case relies on the valid docs.
        DocBin(
            docs=[
                Example.from_dict(nlp.make_doc(text), annotations).reference
                for text, annotations in out_of_range_texts
            ]
        ).to_disk(docs_path)
        assert_exits_with_error(nlp, docs_path)