mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-08 06:04:57 +03:00
Add tests.
This commit is contained in:
parent
4981700ced
commit
1d0f5d3592
|
@ -1,6 +1,6 @@
|
|||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import wasabi.tables
|
||||
|
@ -57,23 +57,26 @@ def find_threshold_cli(
|
|||
|
||||
|
||||
def find_threshold(
|
||||
model_path: Path,
|
||||
doc_path: Path,
|
||||
model_path: Union[str, Path],
|
||||
doc_path: Union[str, Path],
|
||||
*,
|
||||
average: str = _DEFAULTS["average"],
|
||||
pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
|
||||
n_trials: int = _DEFAULTS["n_trials"],
|
||||
beta: float = _DEFAULTS["beta"],
|
||||
) -> None:
|
||||
verbose: bool = True,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
|
||||
model_path (Path): Path to file with trained model.
|
||||
doc_path (Path): Path to file with DocBin with docs to use for threshold search.
|
||||
model_path (Union[str, Path]): Path to file with trained model.
|
||||
doc_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
|
||||
average (str): How to average F-scores across labels. One of ('micro', 'macro').
|
||||
pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
|
||||
is selected. If there are multiple, an error is raised.
|
||||
n_trials (int): Number of trials to determine optimal thresholds
|
||||
beta (float): Beta for F1 calculation. Ignored if different metric is used.
|
||||
verbose (bool): Whether to print non-error-related output to stdout.
|
||||
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
|
||||
"""
|
||||
|
||||
nlp = util.load_model(model_path)
|
||||
|
@ -90,10 +93,13 @@ def find_threshold(
|
|||
if pipe_name and _pipe_name == pipe_name:
|
||||
if not isinstance(_pipe, MultiLabel_TextCategorizer):
|
||||
wasabi.msg.fail(
|
||||
"Specified component {component} is not of type `MultiLabel_TextCategorizer`.",
|
||||
"Specified component '{component}' is not of type `MultiLabel_TextCategorizer`.".format(
|
||||
component=pipe_name
|
||||
),
|
||||
exits=1,
|
||||
)
|
||||
pipe = _pipe
|
||||
print(pipe_name, _pipe_name, pipe.labels)
|
||||
break
|
||||
elif pipe_name is None:
|
||||
if isinstance(_pipe, MultiLabel_TextCategorizer):
|
||||
|
@ -116,9 +122,10 @@ def find_threshold(
|
|||
exits=1,
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"Searching threshold with the best {average} F-score for pipe '{selected_pipe_name}' with {n_trials} trials"
|
||||
f" and beta = {beta}."
|
||||
f"Searching threshold with the best {average} F-score for component '{selected_pipe_name}' with {n_trials} "
|
||||
f"trials and beta = {beta}."
|
||||
)
|
||||
|
||||
thresholds = numpy.linspace(0, 1, n_trials)
|
||||
|
@ -146,13 +153,15 @@ def find_threshold(
|
|||
# Collect count stats per threshold value and label.
|
||||
for threshold in thresholds:
|
||||
for label, score in pred_doc.cats.items():
|
||||
if label not in pipe.labels:
|
||||
continue
|
||||
label_value = int(score >= threshold)
|
||||
if label_value == ref_doc.cats[label] == 1:
|
||||
pred_pos_counts[threshold][True][label] += 1
|
||||
elif label_value == 1 and ref_doc.cats[label] == 0:
|
||||
pred_pos_counts[threshold][False][label] += 1
|
||||
|
||||
# Compute f_scores.
|
||||
# Compute F-scores.
|
||||
for threshold in thresholds:
|
||||
for label in ref_pos_counts:
|
||||
n_pos_preds = (
|
||||
|
@ -188,10 +197,9 @@ def find_threshold(
|
|||
) / len(ref_pos_counts)
|
||||
|
||||
best_threshold = max(f_scores, key=f_scores.get)
|
||||
if verbose:
|
||||
print(
|
||||
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}."
|
||||
)
|
||||
print(
|
||||
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
|
||||
wasabi.tables.table(
|
||||
data=[
|
||||
(threshold, label, f_score)
|
||||
|
@ -205,3 +213,5 @@ def find_threshold(
|
|||
header=["Threshold", f"F-Score ({average})"],
|
||||
),
|
||||
)
|
||||
|
||||
return best_threshold, f_scores[best_threshold]
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import os
|
||||
import math
|
||||
from random import sample
|
||||
from typing import Counter
|
||||
from typing import Counter, Iterable, Tuple, List
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
import srsly
|
||||
from click import NoSuchOption
|
||||
|
@ -26,19 +26,23 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
|||
from spacy.cli.package import get_third_party_dependencies
|
||||
from spacy.cli.package import _is_permitted_package_name
|
||||
from spacy.cli.validate import get_model_pkgs
|
||||
from spacy.cli.find_threshold import find_threshold
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.nl import Dutch
|
||||
from spacy.language import Language
|
||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.tokens.span import Span
|
||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
||||
from spacy.training.converters import iob_to_docs
|
||||
from spacy.pipeline import TextCategorizer, Pipe
|
||||
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
|
||||
|
||||
from ..cli.init_pipeline import _init_labels
|
||||
from .util import make_tempdir
|
||||
# from ..cli.init_pipeline import _init_labels
|
||||
# from .util import make_tempdir
|
||||
from spacy.cli.init_pipeline import _init_labels
|
||||
from spacy.tests.util import make_tempdir
|
||||
|
||||
|
||||
@pytest.mark.issue(4665)
|
||||
|
@ -855,3 +859,123 @@ def test_span_length_freq_dist_output_must_be_correct():
|
|||
span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
|
||||
assert sum(span_freqs.values()) >= threshold
|
||||
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
||||
|
||||
|
||||
def test_cli_find_threshold(capsys):
    """Exercise find_threshold() on a small multi-label textcat pipeline.

    Covers one smoke test of the happy path (with and without an explicit
    component name) and every failure mode that is expected to terminate with
    exit code 1: wrong component type, ambiguous component selection, unknown
    component name, pipeline without a textcat component, and reference scores
    outside of [0, 1].
    """
    # Pin the number of trials so that the expected threshold below does not silently depend on the default value of
    # find_threshold()'s `n_trials` argument.
    n_trials = 10
    # Lowest threshold above 0 in the search grid used by find_threshold().
    expected_threshold = numpy.linspace(0, 1, n_trials)[1]

    def make_get_examples_multi_label(_nlp: Language) -> List[Example]:
        """Build the two multi-label training/reference examples."""
        return [
            Example.from_dict(_nlp.make_doc(t[0]), t[1])
            for t in [
                (
                    "I'm angry and confused",
                    {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
                ),
                (
                    "I'm confused but happy",
                    {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
                ),
            ]
        ]

    def init_nlp(
        component_factory_names: Tuple[str, ...] = (),
    ) -> Tuple[Language, List[Example]]:
        """Create and briefly train an English pipeline with a `tc_multi` component.

        component_factory_names (Tuple[str, ...]): Factory names of additional components to add after `tc_multi`.
        RETURNS (Tuple[Language, List[Example]]): The updated pipeline and the examples it was updated with.
        """
        _nlp = English()

        textcat: TextCategorizer = _nlp.add_pipe(factory_name="textcat_multilabel", name="tc_multi")  # type: ignore
        textcat.add_label("ANGRY")
        textcat.add_label("CONFUSED")
        textcat.add_label("HAPPY")
        for cfn in component_factory_names:
            comp = _nlp.add_pipe(cfn)
            # Text categorizers need at least one label, otherwise initialization fails.
            if isinstance(comp, TextCategorizer):
                comp.add_label("dummy")

        _nlp.initialize()

        _examples = make_get_examples_multi_label(_nlp)
        for _ in range(5):
            _nlp.update(_examples)

        return _nlp, _examples

    with make_tempdir() as docs_dir:
        # Check whether find_threshold() identifies lowest threshold above 0 as (first) ideal threshold, as this matches
        # the current model behavior with the examples above. This can break once the model behavior changes and serves
        # mostly as a smoke test.
        nlp, examples = init_nlp()
        DocBin(docs=[example.reference for example in examples]).to_disk(
            docs_dir / "docs"
        )
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            assert (
                find_threshold(
                    nlp_dir, docs_dir / "docs", n_trials=n_trials, verbose=False
                )[0]
                == expected_threshold
            )

        # Specifying name of non-MultiLabel_TextCategorizer component should fail.
        nlp, _ = init_nlp(("sentencizer",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="sentencizer")
            # Checked after the `raises` block: statements following the raising call inside it would never run.
            assert error.value.code == 1

        # Having multiple textcat_multilabel components without specifying the name should fail.
        nlp, _ = init_nlp(("textcat_multilabel",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs")
            assert error.value.code == 1

        # Having multiple textcat_multilabel components should work when specifying the name.
        nlp, _ = init_nlp(("textcat_multilabel",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            assert (
                find_threshold(
                    nlp_dir,
                    docs_dir / "docs",
                    pipe_name="tc_multi",
                    n_trials=n_trials,
                    verbose=False,
                )[0]
                == expected_threshold
            )

        # Specifying the name of a non-existing pipe should fail.
        nlp, _ = init_nlp()
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="_")
            assert error.value.code == 1

        # Using a pipe with no textcat components should fail.
        nlp = English()
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs")
            assert error.value.code == 1

        # Specifying scores not in range 0 <= x <= 1 should fail.
        nlp, _ = init_nlp()
        # Overwrite the reference docs with a DocBin containing an out-of-range score for CONFUSED.
        DocBin(
            docs=[
                Example.from_dict(nlp.make_doc(t[0]), t[1]).reference
                for t in [
                    (
                        "I'm angry and confused",
                        {"cats": {"ANGRY": 1.0, "CONFUSED": 2.0, "HAPPY": 0.0}},
                    ),
                    (
                        "I'm confused but happy",
                        {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
                    ),
                ]
            ]
        ).to_disk(docs_dir / "docs")
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs")
            assert error.value.code == 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user