Add tests.

Raphael Mitsch 2022-08-08 16:39:22 +02:00
parent 4981700ced
commit 1d0f5d3592
2 changed files with 168 additions and 34 deletions

View File

@@ -1,6 +1,6 @@
 from pathlib import Path
 import logging
-from typing import Optional
+from typing import Optional, Tuple, Union
 import numpy
 import wasabi.tables
@@ -57,23 +57,26 @@ def find_threshold_cli(
 def find_threshold(
-    model_path: Path,
-    doc_path: Path,
+    model_path: Union[str, Path],
+    doc_path: Union[str, Path],
     *,
     average: str = _DEFAULTS["average"],
     pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
     n_trials: int = _DEFAULTS["n_trials"],
     beta: float = _DEFAULTS["beta"],
-) -> None:
+    verbose: bool = True,
+) -> Tuple[float, float]:
""" """
Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric. Runs prediction trials for `textcat` models with varying tresholds to maximize the specified metric.
model_path (Path): Path to file with trained model. model_path (Union[str, Path]): Path to file with trained model.
doc_path (Path): Path to file with DocBin with docs to use for threshold search. doc_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
average (str): How to average F-scores across labels. One of ('micro', 'macro'). average (str): How to average F-scores across labels. One of ('micro', 'macro').
pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
is seleted. If there are multiple, an error is raised. is seleted. If there are multiple, an error is raised.
n_trials (int): Number of trials to determine optimal thresholds n_trials (int): Number of trials to determine optimal thresholds
beta (float): Beta for F1 calculation. Ignored if different metric is used. beta (float): Beta for F1 calculation. Ignored if different metric is used.
verbose (bool): Whether to print non-error-related output to stdout.
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
""" """
nlp = util.load_model(model_path) nlp = util.load_model(model_path)
@@ -90,10 +93,13 @@ def find_threshold(
         if pipe_name and _pipe_name == pipe_name:
             if not isinstance(_pipe, MultiLabel_TextCategorizer):
                 wasabi.msg.fail(
-                    "Specified component {component} is not of type `MultiLabel_TextCategorizer`.",
+                    "Specified component '{component}' is not of type `MultiLabel_TextCategorizer`.".format(
+                        component=pipe_name
+                    ),
                     exits=1,
                 )
             pipe = _pipe
             break
         elif pipe_name is None:
             if isinstance(_pipe, MultiLabel_TextCategorizer):
@@ -116,9 +122,10 @@ def find_threshold(
             exits=1,
         )
-    print(
-        f"Searching threshold with the best {average} F-score for pipe '{selected_pipe_name}' with {n_trials} trials"
-        f" and beta = {beta}."
-    )
+    if verbose:
+        print(
+            f"Searching threshold with the best {average} F-score for component '{selected_pipe_name}' with {n_trials} "
+            f"trials and beta = {beta}."
+        )
     thresholds = numpy.linspace(0, 1, n_trials)
@@ -146,13 +153,15 @@ def find_threshold(
         # Collect count stats per threshold value and label.
         for threshold in thresholds:
             for label, score in pred_doc.cats.items():
+                if label not in pipe.labels:
+                    continue
                 label_value = int(score >= threshold)
                 if label_value == ref_doc.cats[label] == 1:
                     pred_pos_counts[threshold][True][label] += 1
                 elif label_value == 1 and ref_doc.cats[label] == 0:
                     pred_pos_counts[threshold][False][label] += 1
-    # Compute f_scores.
+    # Compute F-scores.
     for threshold in thresholds:
         for label in ref_pos_counts:
             n_pos_preds = (
@ -188,10 +197,9 @@ def find_threshold(
) / len(ref_pos_counts) ) / len(ref_pos_counts)
best_threshold = max(f_scores, key=f_scores.get) best_threshold = max(f_scores, key=f_scores.get)
if verbose:
print( print(
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}." f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
)
print(
wasabi.tables.table( wasabi.tables.table(
data=[ data=[
(threshold, label, f_score) (threshold, label, f_score)
@@ -205,3 +213,5 @@ def find_threshold(
                 header=["Threshold", f"F-Score ({average})"],
             ),
         )
+
+    return best_threshold, f_scores[best_threshold]
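
For context, a minimal sketch of calling the updated API after this change; the pipeline and DocBin paths below are hypothetical, and only arguments shown in the signature above are used:

    from spacy.cli.find_threshold import find_threshold

    # Hypothetical paths: any pipeline directory with a multilabel textcat
    # component and a DocBin with reference categories should work here.
    best_threshold, best_f_score = find_threshold(
        "training/model-best",
        "corpus/dev.spacy",
        pipe_name="textcat_multilabel",  # optional if only one such component exists
        verbose=False,  # suppress the search log and the per-threshold table
    )
    print(f"Best threshold: {best_threshold:.4f} (F-score: {best_f_score:.4f})")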

View File

@@ -1,8 +1,8 @@
 import os
 import math
-from random import sample
-from typing import Counter
+from typing import Counter, Iterable, Tuple, List
+import numpy
 import pytest
 import srsly
 from click import NoSuchOption
@@ -26,19 +26,23 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
 from spacy.cli.package import get_third_party_dependencies
 from spacy.cli.package import _is_permitted_package_name
 from spacy.cli.validate import get_model_pkgs
+from spacy.cli.find_threshold import find_threshold
 from spacy.lang.en import English
 from spacy.lang.nl import Dutch
 from spacy.language import Language
 from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
-from spacy.tokens import Doc
+from spacy.tokens import Doc, DocBin
 from spacy.tokens.span import Span
 from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
 from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
 from spacy.training.converters import iob_to_docs
+from spacy.pipeline import TextCategorizer, Pipe
 from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
-from ..cli.init_pipeline import _init_labels
-from .util import make_tempdir
+from spacy.cli.init_pipeline import _init_labels
+from spacy.tests.util import make_tempdir
 
 
 @pytest.mark.issue(4665)
@@ -855,3 +859,123 @@ def test_span_length_freq_dist_output_must_be_correct():
     span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
     assert sum(span_freqs.values()) >= threshold
     assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
+
+
+def test_cli_find_threshold(capsys):
+    def make_get_examples_multi_label(_nlp: Language) -> List[Example]:
+        return [
+            Example.from_dict(_nlp.make_doc(t[0]), t[1])
+            for t in [
+                (
+                    "I'm angry and confused",
+                    {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
+                ),
+                (
+                    "I'm confused but happy",
+                    {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
+                ),
+            ]
+        ]
+
+    def init_nlp(
+        component_factory_names: Tuple[str, ...] = (),
+    ) -> Tuple[Language, List[Example]]:
+        _nlp = English()
+        textcat: TextCategorizer = _nlp.add_pipe(factory_name="textcat_multilabel", name="tc_multi")  # type: ignore
+        textcat.add_label("ANGRY")
+        textcat.add_label("CONFUSED")
+        textcat.add_label("HAPPY")
+        for cfn in component_factory_names:
+            comp = _nlp.add_pipe(cfn)
+            if isinstance(comp, TextCategorizer):
+                comp.add_label("dummy")
+        _nlp.initialize()
+        _examples = make_get_examples_multi_label(_nlp)
+        for _ in range(5):
+            _nlp.update(_examples)
+        return _nlp, _examples
+
+    with make_tempdir() as docs_dir:
+        # Check whether find_threshold() identifies the lowest threshold above 0 as the (first) ideal threshold, as
+        # this matches the current model behavior with the examples above. This can break once the model behavior
+        # changes and serves mostly as a smoke test.
+        nlp, examples = init_nlp()
+        DocBin(docs=[example.reference for example in examples]).to_disk(
+            docs_dir / "docs"
+        )
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            assert (
+                find_threshold(nlp_dir, docs_dir / "docs", verbose=False)[0]
+                == numpy.linspace(0, 1, 10)[1]
+            )
+
+        # Specifying the name of a non-MultiLabel_TextCategorizer component should fail.
+        nlp, _ = init_nlp(("sentencizer",))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="sentencizer")
+            assert error.value.code == 1
+
+        # Having multiple textcat_multilabel components without specifying the name should fail.
+        nlp, _ = init_nlp(("textcat_multilabel",))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs")
+            assert error.value.code == 1
+
+        # Having multiple textcat_multilabel components should work when specifying the name.
+        nlp, _ = init_nlp(("textcat_multilabel",))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            assert (
+                find_threshold(
+                    nlp_dir, docs_dir / "docs", pipe_name="tc_multi", verbose=False
+                )[0]
+                == numpy.linspace(0, 1, 10)[1]
+            )
+
+        # Specifying the name of a non-existing pipe should fail.
+        nlp, _ = init_nlp()
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="_")
+            assert error.value.code == 1
+
+        # Using a pipeline with no textcat components should fail.
+        nlp = English()
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs")
+            assert error.value.code == 1
+
+        # Specifying scores not in range 0 <= x <= 1 should fail.
+        nlp, _ = init_nlp()
+        DocBin(
+            docs=[
+                Example.from_dict(nlp.make_doc(t[0]), t[1]).reference
+                for t in [
+                    (
+                        "I'm angry and confused",
+                        {"cats": {"ANGRY": 1.0, "CONFUSED": 2.0, "HAPPY": 0.0}},
+                    ),
+                    (
+                        "I'm confused but happy",
+                        {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
+                    ),
+                ]
+            ]
+        ).to_disk(docs_dir / "docs")
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            with pytest.raises(SystemExit) as error:
+                find_threshold(nlp_dir, docs_dir / "docs")
+            assert error.value.code == 1
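
As a sanity check on the smoke-test assertion above: the expected value is the second point of the search grid, which assumes the default number of trials is 10 (that is what the comparison against numpy.linspace(0, 1, 10) implies). A quick sketch:

    import numpy

    # With n_trials=10, the thresholds span [0, 1] in steps of 1/9.
    thresholds = numpy.linspace(0, 1, 10)
    print(thresholds[1])  # ~0.1111: the lowest threshold above 0, as the test expects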