mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-08 14:14:57 +03:00
Add tests.
This commit is contained in:
parent
4981700ced
commit
1d0f5d3592
|
@ -1,6 +1,6 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import wasabi.tables
|
import wasabi.tables
|
||||||
|
@ -57,23 +57,26 @@ def find_threshold_cli(
|
||||||
|
|
||||||
|
|
||||||
def find_threshold(
|
def find_threshold(
|
||||||
model_path: Path,
|
model_path: Union[str, Path],
|
||||||
doc_path: Path,
|
doc_path: Union[str, Path],
|
||||||
*,
|
*,
|
||||||
average: str = _DEFAULTS["average"],
|
average: str = _DEFAULTS["average"],
|
||||||
pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
|
pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
|
||||||
n_trials: int = _DEFAULTS["n_trials"],
|
n_trials: int = _DEFAULTS["n_trials"],
|
||||||
beta: float = _DEFAULTS["beta"],
|
beta: float = _DEFAULTS["beta"],
|
||||||
) -> None:
|
verbose: bool = True,
|
||||||
|
) -> Tuple[float, float]:
|
||||||
"""
|
"""
|
||||||
Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
|
Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
|
||||||
model_path (Path): Path to file with trained model.
|
model_path (Union[str, Path]): Path to file with trained model.
|
||||||
doc_path (Path): Path to file with DocBin with docs to use for threshold search.
|
doc_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
|
||||||
average (str): How to average F-scores across labels. One of ('micro', 'macro').
|
average (str): How to average F-scores across labels. One of ('micro', 'macro').
|
||||||
pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
|
pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
|
||||||
is selected. If there are multiple, an error is raised.
|
is selected. If there are multiple, an error is raised.
|
||||||
n_trials (int): Number of trials to determine optimal thresholds
|
n_trials (int): Number of trials to determine optimal thresholds
|
||||||
beta (float): Beta for F1 calculation. Ignored if different metric is used.
|
beta (float): Beta for F1 calculation. Ignored if different metric is used.
|
||||||
|
verbose (bool): Whether to print non-error-related output to stdout.
|
||||||
|
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nlp = util.load_model(model_path)
|
nlp = util.load_model(model_path)
|
||||||
|
@ -90,10 +93,13 @@ def find_threshold(
|
||||||
if pipe_name and _pipe_name == pipe_name:
|
if pipe_name and _pipe_name == pipe_name:
|
||||||
if not isinstance(_pipe, MultiLabel_TextCategorizer):
|
if not isinstance(_pipe, MultiLabel_TextCategorizer):
|
||||||
wasabi.msg.fail(
|
wasabi.msg.fail(
|
||||||
"Specified component {component} is not of type `MultiLabel_TextCategorizer`.",
|
"Specified component '{component}' is not of type `MultiLabel_TextCategorizer`.".format(
|
||||||
|
component=pipe_name
|
||||||
|
),
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
pipe = _pipe
|
pipe = _pipe
|
||||||
|
print(pipe_name, _pipe_name, pipe.labels)
|
||||||
break
|
break
|
||||||
elif pipe_name is None:
|
elif pipe_name is None:
|
||||||
if isinstance(_pipe, MultiLabel_TextCategorizer):
|
if isinstance(_pipe, MultiLabel_TextCategorizer):
|
||||||
|
@ -116,10 +122,11 @@ def find_threshold(
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
if verbose:
|
||||||
f"Searching threshold with the best {average} F-score for pipe '{selected_pipe_name}' with {n_trials} trials"
|
print(
|
||||||
f" and beta = {beta}."
|
f"Searching threshold with the best {average} F-score for component '{selected_pipe_name}' with {n_trials} "
|
||||||
)
|
f"trials and beta = {beta}."
|
||||||
|
)
|
||||||
|
|
||||||
thresholds = numpy.linspace(0, 1, n_trials)
|
thresholds = numpy.linspace(0, 1, n_trials)
|
||||||
ref_pos_counts = {label: 0 for label in pipe.labels}
|
ref_pos_counts = {label: 0 for label in pipe.labels}
|
||||||
|
@ -146,13 +153,15 @@ def find_threshold(
|
||||||
# Collect count stats per threshold value and label.
|
# Collect count stats per threshold value and label.
|
||||||
for threshold in thresholds:
|
for threshold in thresholds:
|
||||||
for label, score in pred_doc.cats.items():
|
for label, score in pred_doc.cats.items():
|
||||||
|
if label not in pipe.labels:
|
||||||
|
continue
|
||||||
label_value = int(score >= threshold)
|
label_value = int(score >= threshold)
|
||||||
if label_value == ref_doc.cats[label] == 1:
|
if label_value == ref_doc.cats[label] == 1:
|
||||||
pred_pos_counts[threshold][True][label] += 1
|
pred_pos_counts[threshold][True][label] += 1
|
||||||
elif label_value == 1 and ref_doc.cats[label] == 0:
|
elif label_value == 1 and ref_doc.cats[label] == 0:
|
||||||
pred_pos_counts[threshold][False][label] += 1
|
pred_pos_counts[threshold][False][label] += 1
|
||||||
|
|
||||||
# Compute f_scores.
|
# Compute F-scores.
|
||||||
for threshold in thresholds:
|
for threshold in thresholds:
|
||||||
for label in ref_pos_counts:
|
for label in ref_pos_counts:
|
||||||
n_pos_preds = (
|
n_pos_preds = (
|
||||||
|
@ -188,20 +197,21 @@ def find_threshold(
|
||||||
) / len(ref_pos_counts)
|
) / len(ref_pos_counts)
|
||||||
|
|
||||||
best_threshold = max(f_scores, key=f_scores.get)
|
best_threshold = max(f_scores, key=f_scores.get)
|
||||||
print(
|
if verbose:
|
||||||
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}."
|
print(
|
||||||
)
|
f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
|
||||||
print(
|
wasabi.tables.table(
|
||||||
wasabi.tables.table(
|
data=[
|
||||||
data=[
|
(threshold, label, f_score)
|
||||||
(threshold, label, f_score)
|
for threshold, label_f_scores in f_scores_per_label.items()
|
||||||
for threshold, label_f_scores in f_scores_per_label.items()
|
for label, f_score in label_f_scores.items()
|
||||||
for label, f_score in label_f_scores.items()
|
],
|
||||||
],
|
header=["Threshold", "Label", "F-Score"],
|
||||||
header=["Threshold", "Label", "F-Score"],
|
),
|
||||||
),
|
wasabi.tables.table(
|
||||||
wasabi.tables.table(
|
data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
|
||||||
data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
|
header=["Threshold", f"F-Score ({average})"],
|
||||||
header=["Threshold", f"F-Score ({average})"],
|
),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
return best_threshold, f_scores[best_threshold]
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
from random import sample
|
from typing import Counter, Iterable, Tuple, List
|
||||||
from typing import Counter
|
|
||||||
|
|
||||||
|
import numpy
|
||||||
import pytest
|
import pytest
|
||||||
import srsly
|
import srsly
|
||||||
from click import NoSuchOption
|
from click import NoSuchOption
|
||||||
|
@ -26,19 +26,23 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
||||||
from spacy.cli.package import get_third_party_dependencies
|
from spacy.cli.package import get_third_party_dependencies
|
||||||
from spacy.cli.package import _is_permitted_package_name
|
from spacy.cli.package import _is_permitted_package_name
|
||||||
from spacy.cli.validate import get_model_pkgs
|
from spacy.cli.validate import get_model_pkgs
|
||||||
|
from spacy.cli.find_threshold import find_threshold
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.nl import Dutch
|
from spacy.lang.nl import Dutch
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc, DocBin
|
||||||
from spacy.tokens.span import Span
|
from spacy.tokens.span import Span
|
||||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||||
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
||||||
from spacy.training.converters import iob_to_docs
|
from spacy.training.converters import iob_to_docs
|
||||||
|
from spacy.pipeline import TextCategorizer, Pipe
|
||||||
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
|
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
|
||||||
|
|
||||||
from ..cli.init_pipeline import _init_labels
|
# from ..cli.init_pipeline import _init_labels
|
||||||
from .util import make_tempdir
|
# from .util import make_tempdir
|
||||||
|
from spacy.cli.init_pipeline import _init_labels
|
||||||
|
from spacy.tests.util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(4665)
|
@pytest.mark.issue(4665)
|
||||||
|
@ -855,3 +859,123 @@ def test_span_length_freq_dist_output_must_be_correct():
|
||||||
span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
|
span_freqs = _get_spans_length_freq_dist(sample_span_lengths, threshold)
|
||||||
assert sum(span_freqs.values()) >= threshold
|
assert sum(span_freqs.values()) >= threshold
|
||||||
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_find_threshold(capsys):
    """Smoke-test `find_threshold()` on a small multilabel textcat pipeline.

    Covers the happy path (the best threshold is compared against the second
    value of `numpy.linspace(0, 1, 10)`) and the failure modes that are
    expected to terminate with exit code 1: a named component of the wrong
    type, several candidate components without a name, an unknown component
    name, a pipeline without any textcat component, and reference scores
    outside of [0, 1].
    """

    def make_get_examples_multi_label(_nlp: Language) -> List[Example]:
        # Two tiny multilabel examples; every label occurs as positive at
        # least once.
        return [
            Example.from_dict(_nlp.make_doc(t[0]), t[1])
            for t in [
                (
                    "I'm angry and confused",
                    {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
                ),
                (
                    "I'm confused but happy",
                    {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
                ),
            ]
        ]

    def init_nlp(
        component_factory_names: Tuple[str, ...] = (),
    ) -> Tuple[Language, List[Example]]:
        # Builds an English pipeline with a `textcat_multilabel` component
        # named "tc_multi", plus one extra component per requested factory
        # name, and runs a few update steps on the examples above.
        _nlp = English()

        textcat: TextCategorizer = _nlp.add_pipe(  # type: ignore
            factory_name="textcat_multilabel", name="tc_multi"
        )
        textcat.add_label("ANGRY")
        textcat.add_label("CONFUSED")
        textcat.add_label("HAPPY")
        for cfn in component_factory_names:
            comp = _nlp.add_pipe(cfn)
            if isinstance(comp, TextCategorizer):
                comp.add_label("dummy")

        _nlp.initialize()

        _examples = make_get_examples_multi_label(_nlp)
        for _ in range(5):
            _nlp.update(_examples)

        return _nlp, _examples

    with make_tempdir() as docs_dir:
        # Check whether find_threshold() identifies lowest threshold above 0 as (first) ideal threshold, as this matches
        # the current model behavior with the examples above. This can break once the model behavior changes and serves
        # mostly as a smoke test.
        nlp, examples = init_nlp()
        DocBin(docs=[example.reference for example in examples]).to_disk(
            docs_dir / "docs"
        )
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            assert (
                find_threshold(nlp_dir, docs_dir / "docs", verbose=False)[0]
                == numpy.linspace(0, 1, 10)[1]
            )

        # Specifying name of non-MultiLabel_TextCategorizer component should fail.
        nlp, _ = init_nlp(("sentencizer",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="sentencizer")
            assert error.value.code == 1

        # Having multiple textcat_multilabel components without specifying the name should fail.
        nlp, _ = init_nlp(("textcat_multilabel",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs")
            assert error.value.code == 1

        # Having multiple textcat_multilabel components should work when specifying the name.
        nlp, _ = init_nlp(("textcat_multilabel",))
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            assert (
                find_threshold(
                    nlp_dir, docs_dir / "docs", pipe_name="tc_multi", verbose=False
                )[0]
                == numpy.linspace(0, 1, 10)[1]
            )

        # Specifying the name of a non-existing pipe should fail.
        nlp, _ = init_nlp()
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="_")
            assert error.value.code == 1

        # Using a pipe with no textcat components should fail.
        nlp = English()
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs")
            assert error.value.code == 1

        # Specifying scores not in range 0 <= x <= 1 should fail.
        # This overwrites the DocBin written above, so it has to stay the last case.
        nlp, _ = init_nlp()
        DocBin(
            docs=[
                Example.from_dict(nlp.make_doc(t[0]), t[1]).reference
                for t in [
                    (
                        "I'm angry and confused",
                        {"cats": {"ANGRY": 1.0, "CONFUSED": 2.0, "HAPPY": 0.0}},
                    ),
                    (
                        "I'm confused but happy",
                        {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
                    ),
                ]
            ]
        ).to_disk(docs_dir / "docs")
        with make_tempdir() as nlp_dir:
            nlp.to_disk(nlp_dir)
            with pytest.raises(SystemExit) as error:
                find_threshold(nlp_dir, docs_dir / "docs")
            assert error.value.code == 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user