Mirror of https://github.com/explosion/spaCy.git
Generalize component and threshold handling. Harmonize arguments with 'spacy evaluate' CLI.
This commit is contained in:
parent 63c80288ef
commit 3a0a3854f7
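For orientation, this is roughly how the reworked `find_threshold` is called after this commit; the keyword arguments mirror the test changes further down. The import path follows the module edited below, while the model directory, data path and component name are illustrative placeholders rather than values from the commit:

from pathlib import Path
from spacy.cli.find_threshold import find_threshold

# Hypothetical trained pipeline containing a multilabel textcat component
# named "tc_multi", evaluated on binary .spacy data.
best_threshold, best_score = find_threshold(
    model="training/model-best",
    data_path=Path("corpus/dev.spacy"),
    pipe_name="tc_multi",
    threshold_key="threshold",
    scores_key="cats_macro_f",
    silent=False,
)
print(f"threshold={best_threshold:.4f}, cats_macro_f={best_score:.4f}")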
spacy/cli/find_threshold.py

@@ -1,16 +1,23 @@
+import functools
+import operator
 from pathlib import Path
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any, Dict, List
 
 import numpy
 import wasabi.tables
 
+from ..training import Corpus
 from ._util import app, Arg, Opt, import_code, setup_gpu
 from .. import util
-from ..pipeline import MultiLabel_TextCategorizer, Pipe
-from ..tokens import DocBin
 
-_DEFAULTS = {"average": "micro", "n_trials": 10, "beta": 1, "use_gpu": -1}
+_DEFAULTS = {
+    "average": "micro",
+    "n_trials": 10,
+    "beta": 1,
+    "use_gpu": -1,
+    "gold_preproc": False,
+}
 
 
 @app.command(
@@ -21,12 +28,14 @@ def find_threshold_cli(
     # fmt: off
     model: str = Arg(..., help="Model name or path"),
     data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
-    pipe_name: str = Opt(..., "--pipe_name", "-p", help="Name of pipe to examine thresholds for"),
-    average: str = Arg(_DEFAULTS["average"], help="How to aggregate F-scores over labels. One of ('micro', 'macro')", exists=True, allow_dash=True),
+    pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
+    threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
+    scores_key: str = Arg(..., help="Name of score to metric to optimize"),
     n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
     beta: float = Opt(_DEFAULTS["beta"], "--beta", help="Beta for F1 calculation. Ignored if different metric is used"),
     code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
+    gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
     verbose: bool = Opt(False, "--silent", "-V", "-VV", help="Display more information for debugging purposes"),
     # fmt: on
 ):
@@ -35,24 +44,30 @@ def find_threshold_cli(
     model (Path): Path to file with trained model.
     data_path (Path): Path to file with DocBin with docs to use for threshold search.
     pipe_name (str): Name of pipe to examine thresholds for.
-    average (str): How to average F-scores across labels. One of ('micro', 'macro').
+    threshold_key (str): Key of threshold attribute in component's configuration.
+    scores_key (str): Name of score to metric to optimize.
     n_trials (int): Number of trials to determine optimal thresholds
     beta (float): Beta for F1 calculation.
     code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
     use_gpu (int): GPU ID or -1 for CPU.
+    gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
+        tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
+        to train/test skew.
     silent (bool): Display more information for debugging purposes
     """
 
     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
     import_code(code_path)
     find_threshold(
-        model,
-        data_path,
+        model=model,
+        data_path=data_path,
         pipe_name=pipe_name,
-        average=average,
+        threshold_key=threshold_key,
+        scores_key=scores_key,
         n_trials=n_trials,
         beta=beta,
         use_gpu=use_gpu,
+        gold_preproc=gold_preproc,
        silent=False,
     )
 
@@ -60,12 +75,14 @@ def find_threshold_cli(
 def find_threshold(
     model: str,
     data_path: Path,
+    pipe_name: str,
+    threshold_key: str,
+    scores_key: str,
     *,
-    pipe_name: str,  # type: ignore
-    average: str = _DEFAULTS["average"],  # type: ignore
-    n_trials: int = _DEFAULTS["n_trials"],  # type: ignore
-    beta: float = _DEFAULTS["beta"],  # type: ignore,
+    n_trials: int = _DEFAULTS["n_trials"],
+    beta: float = _DEFAULTS["beta"],
     use_gpu: int = _DEFAULTS["use_gpu"],
+    gold_preproc: bool = _DEFAULTS["gold_preproc"],
     silent: bool = True,
 ) -> Tuple[float, float]:
     """
@@ -73,10 +90,14 @@ def find_threshold(
     model (Union[str, Path]): Path to file with trained model.
     data_path (Union[str, Path]): Path to file with DocBin with docs to use for threshold search.
     pipe_name (str): Name of pipe to examine thresholds for.
-    average (str): How to average F-scores across labels. One of ('micro', 'macro').
+    threshold_key (str): Key of threshold attribute in component's configuration.
+    scores_key (str): Name of score to metric to optimize.
     n_trials (int): Number of trials to determine optimal thresholds.
     beta (float): Beta for F1 calculation.
     use_gpu (int): GPU ID or -1 for CPU.
+    gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
+        tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
+        to train/test skew.
     silent (bool): Whether to print non-error-related output to stdout.
     RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
     """
@@ -86,127 +107,57 @@ def find_threshold(
     if not data_path.exists():
         wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
     nlp = util.load_model(model)
-    pipe: Optional[Pipe] = None
-    selected_pipe_name: Optional[str] = pipe_name
 
-    if average not in ("micro", "macro"):
-        wasabi.msg.fail(
-            "Expected 'micro' or 'macro' for F-score averaging method, received '{avg_method}'.",
-            exits=1,
-        )
+    try:
+        pipe = nlp.get_pipe(pipe_name)
+    except KeyError as err:
+        wasabi.msg.fail(title=str(err), exits=1)
 
-    for _pipe_name, _pipe in nlp.pipeline:
-        # todo instead of instance check, assert _pipe has a .threshold arg
-        # won't work, actually. e.g. spancat doesn't .threshold.
-        if _pipe_name == pipe_name:
-            if not isinstance(_pipe, MultiLabel_TextCategorizer):
-                wasabi.msg.fail(
-                    "Specified component '{component}' is not of type `MultiLabel_TextCategorizer`.".format(
-                        component=pipe_name
-                    ),
-                    exits=1,
-                )
-            pipe = _pipe
-            break
-
-    if pipe is None:
-        wasabi.msg.fail(
-            f"No component with name {pipe_name} found in pipeline.", exits=1
-        )
-    # This is purely for MyPy. Type checking is done in loop above already.
-    assert isinstance(pipe, MultiLabel_TextCategorizer)
-
-    if silent:
-        print(
-            f"Searching threshold with the best {average} F-score for component '{selected_pipe_name}' with {n_trials} "
+    if not silent:
+        wasabi.msg.info(
+            title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
             f"trials and beta = {beta}."
         )
 
-    thresholds = numpy.linspace(0, 1, n_trials)
-    # todo use Scorer.score_cats. possibly to be extended?
-    ref_pos_counts = {label: 0 for label in pipe.labels}
-    pred_pos_counts = {
-        t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
-        for t in thresholds
-    }
-    f_scores_per_label = {t: {label: 0.0 for label in pipe.labels} for t in thresholds}
-    f_scores = {t: 0.0 for t in thresholds}
+    # Load evaluation corpus.
+    corpus = Corpus(data_path, gold_preproc=gold_preproc)
+    dev_dataset = list(corpus(nlp))
+    config_keys = threshold_key.split(".")
 
-    # Count true/false positives for provided docs.
-    doc_bin = DocBin()
-    doc_bin.from_disk(data_path)
-    for ref_doc in doc_bin.get_docs(nlp.vocab):
-        for label, score in ref_doc.cats.items():
-            if score not in (0, 1):
-                wasabi.msg.fail(
-                    f"Expected category scores in evaluation dataset to be 0 <= x <= 1, received {score}.",
-                    exits=1,
-                )
-            ref_pos_counts[label] += ref_doc.cats[label] == 1
+    def set_nested_item(
+        config: Dict[str, Any], keys: List[str], value: float
+    ) -> Dict[str, Any]:
+        """Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
+        config (Dict[str, Any]): Configuration dictionary.
+        keys (List[Any]):
+        value (float): Value to set.
+        RETURNS (Dict[str, Any]): Updated dictionary.
+        """
+        functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
+        return config
 
-        pred_doc = nlp(ref_doc.text)
-        # Collect count stats per threshold value and label.
-        for threshold in thresholds:
-            for label, score in pred_doc.cats.items():
-                if label not in pipe.labels:
-                    continue
-                label_value = int(score >= threshold)
-                if label_value == ref_doc.cats[label] == 1:
-                    pred_pos_counts[threshold][True][label] += 1
-                elif label_value == 1 and ref_doc.cats[label] == 0:
-                    pred_pos_counts[threshold][False][label] += 1
-
-    # Compute F-scores.
-    for threshold in thresholds:
-        for label in ref_pos_counts:
-            n_pos_preds = (
-                pred_pos_counts[threshold][True][label]
-                + pred_pos_counts[threshold][False][label]
-            )
-            precision = (
-                (pred_pos_counts[threshold][True][label] / n_pos_preds)
-                if n_pos_preds > 0
-                else 0
-            )
-            recall = pred_pos_counts[threshold][True][label] / ref_pos_counts[label]
-            f_scores_per_label[threshold][label] = (
-                (
-                    (1 + beta**2)
-                    * (precision * recall / (precision * beta**2 + recall))
-                )
-                if precision
-                else 0
+    # Evaluate with varying threshold values.
+    scores: Dict[float, float] = {}
+    for threshold in numpy.linspace(0, 1, n_trials):
+        pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
+        scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
+        if not (
+            isinstance(scores[threshold], float) or isinstance(scores[threshold], int)
+        ):
+            wasabi.msg.fail(
+                f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
+                f"scores.",
+                exits=1,
             )
 
-        # Aggregate F-scores.
-        if average == "micro":
-            f_scores[threshold] = sum(
-                [
-                    f_scores_per_label[threshold][label] * ref_pos_counts[label]
-                    for label in ref_pos_counts
-                ]
-            ) / sum(ref_pos_counts.values())
-        else:
-            f_scores[threshold] = sum(
-                [f_scores_per_label[threshold][label] for label in ref_pos_counts]
-            ) / len(ref_pos_counts)
-
-    best_threshold = max(f_scores.keys(), key=(lambda key: f_scores[key]))
-    if silent:
+    best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
+    if not silent:
         print(
-            f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}.",
+            f"Best threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}.",
             wasabi.tables.table(
-                data=[
-                    (threshold, label, f_score)
-                    for threshold, label_f_scores in f_scores_per_label.items()
-                    for label, f_score in label_f_scores.items()
-                ],
-                header=["Threshold", "Label", "F-Score"],
-            ),
-            wasabi.tables.table(
-                data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
-                header=["Threshold", f"F-Score ({average})"],
+                data=[(threshold, score) for threshold, score in scores.items()],
+                header=["Threshold", f"{scores_key}"],
             ),
         )
 
-    return best_threshold, f_scores[best_threshold]
+    return best_threshold, scores[best_threshold]
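The key mechanism the new version introduces is that `threshold_key` is split on "." and written into the component's `cfg` dict via `functools.reduce`/`operator.getitem` before each evaluation pass, so any component exposing a numeric config value can be tuned. A standalone sketch of that nested update; the config dict and the dotted key here are made-up placeholders, not spaCy's actual component config:

import functools
import operator
from typing import Any, Dict, List

def set_nested_item(config: Dict[str, Any], keys: List[str], value: float) -> Dict[str, Any]:
    # Walk down to the parent dict of the last key, then assign in place.
    functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
    return config

# Hypothetical component config with a nested threshold setting.
cfg = {"model": {"exclusive_classes": False}, "scorer": {"threshold": 0.5}}
set_nested_item(cfg, "scorer.threshold".split("."), 0.8)
print(cfg["scorer"]["threshold"])  # 0.8

# A flat key such as threshold_key="threshold" resolves the same way.
flat = {"threshold": 0.5}
set_nested_item(flat, "threshold".split("."), 0.3)
print(flat["threshold"])  # 0.3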
spacy/tests/test_cli.py

@@ -1,6 +1,6 @@
 import os
 import math
-from typing import Counter, Iterable, Tuple, List
+from typing import Counter, Tuple, List, Dict, Any
 
 import numpy
 import pytest
@@ -36,7 +36,7 @@ from spacy.tokens.span import Span
 from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
 from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
 from spacy.training.converters import iob_to_docs
-from spacy.pipeline import TextCategorizer, Pipe
+from spacy.pipeline import TextCategorizer, Pipe, SpanCategorizer
 from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
 
 from ..cli.init_pipeline import _init_labels
@@ -860,38 +860,55 @@ def test_span_length_freq_dist_output_must_be_correct():
 
 
 def test_cli_find_threshold(capsys):
-    def make_get_examples_multi_label(_nlp: Language) -> List[Example]:
-        return [
-            Example.from_dict(_nlp.make_doc(t[0]), t[1])
+    def make_examples(_nlp: Language) -> List[Example]:
+        docs: List[Example] = []
         for t in [
             (
-                "I'm angry and confused",
-                {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}},
-            ),
-            (
-                "I'm confused but happy",
-                {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
-            ),
-        ]
-        ]
+                "I'm angry and confused in the Bank of America.",
+                {
+                    "cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0},
+                    "spans": {"sc": [(7, 10, "ORG")]},
+                },
+            ),
+            (
+                "I'm confused but happy in New York.",
+                {
+                    "cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0},
+                    "spans": {"sc": [(6, 7, "GPE")]},
+                },
+            ),
+        ]:
+            doc = _nlp.make_doc(t[0])
+            docs.append(Example.from_dict(doc, t[1]))
+
+        return docs
 
     def init_nlp(
-        component_factory_names: Tuple[str, ...] = (),
+        components: Tuple[Tuple[str, Dict[str, Any]], ...] = ()
     ) -> Tuple[Language, List[Example]]:
         _nlp = English()
+        textcat: TextCategorizer = _nlp.add_pipe(  # type: ignore
+            factory_name="textcat_multilabel",
+            name="tc_multi",
+            config={"threshold": 0.9},
+        )
+        textcat_labels = ("ANGRY", "CONFUSED", "HAPPY")
+        for label in textcat_labels:
+            textcat.add_label(label)
 
-        textcat: TextCategorizer = _nlp.add_pipe(factory_name="textcat_multilabel", name="tc_multi")  # type: ignore
-        textcat.add_label("ANGRY")
-        textcat.add_label("CONFUSED")
-        textcat.add_label("HAPPY")
-        for cfn in component_factory_names:
-            comp = _nlp.add_pipe(cfn)
+        # Append additional components to pipeline.
+        for cfn, comp_config in components:
+            comp = _nlp.add_pipe(cfn, config=comp_config)
             if isinstance(comp, TextCategorizer):
-                comp.add_label("dummy")
+                for label in textcat_labels:
+                    comp.add_label(label)
+            if isinstance(comp, SpanCategorizer):
+                comp.add_label("GPE")
+                comp.add_label("ORG")
 
         _nlp.initialize()
-
-        _examples = make_get_examples_multi_label(_nlp)
+        _examples = make_examples(_nlp)
+
         for i in range(5):
             _nlp.update(_examples)
 
@@ -903,77 +920,63 @@ def test_cli_find_threshold(capsys):
         # mostly as a smoke test.
         nlp, examples = init_nlp()
         DocBin(docs=[example.reference for example in examples]).to_disk(
-            docs_dir / "docs"
+            docs_dir / "docs.spacy"
         )
         with make_tempdir() as nlp_dir:
             nlp.to_disk(nlp_dir)
             assert (
-                find_threshold(nlp_dir, docs_dir / "docs", verbose=False)[0]
+                find_threshold(
+                    model=nlp_dir,
+                    data_path=docs_dir / "docs.spacy",
+                    pipe_name="tc_multi",
+                    threshold_key="threshold",
+                    scores_key="cats_macro_f",
+                    silent=True,
+                )[0]
                 == numpy.linspace(0, 1, 10)[1]
             )
 
+        # todo fix spancat test
         # Specifying name of non-MultiLabel_TextCategorizer component should fail.
-        nlp, _ = init_nlp(("sentencizer",))
-        with make_tempdir() as nlp_dir:
-            nlp.to_disk(nlp_dir)
-            with pytest.raises(SystemExit) as error:
-                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="sentencizer")
-            assert error.value.code == 1
-
-        # Having multiple textcat_multilabel components without specifying the name should fail.
-        nlp, _ = init_nlp(("textcat_multilabel",))
-        with make_tempdir() as nlp_dir:
-            nlp.to_disk(nlp_dir)
-            with pytest.raises(SystemExit) as error:
-                find_threshold(nlp_dir, docs_dir / "docs")
-            assert error.value.code == 1
-
-        # Having multiple textcat_multilabel components should work when specifying the name.
-        nlp, _ = init_nlp(("textcat_multilabel",))
+        nlp, _ = init_nlp((("spancat", {"spans_key": "sc", "threshold": 0.5}),))
        with make_tempdir() as nlp_dir:
             nlp.to_disk(nlp_dir)
             assert (
                 find_threshold(
-                    nlp_dir, docs_dir / "docs", pipe_name="tc_multi", verbose=False
+                    model=nlp_dir,
+                    data_path=docs_dir / "docs.spacy",
+                    pipe_name="spancat",
+                    threshold_key="threshold",
+                    scores_key="spans_sc_f",
+                    silent=True,
                 )[0]
                 == numpy.linspace(0, 1, 10)[1]
             )
 
+        # Having multiple textcat_multilabel components should work, since the name has to be specified.
+        nlp, _ = init_nlp((("textcat_multilabel", {}),))
+        with make_tempdir() as nlp_dir:
+            nlp.to_disk(nlp_dir)
+            assert find_threshold(
+                model=nlp_dir,
+                data_path=docs_dir / "docs.spacy",
+                pipe_name="tc_multi",
+                threshold_key="threshold",
+                scores_key="cats_macro_f",
+                silent=True,
+            )
+
         # Specifying the name of an non-existing pipe should fail.
         nlp, _ = init_nlp()
         with make_tempdir() as nlp_dir:
             nlp.to_disk(nlp_dir)
             with pytest.raises(SystemExit) as error:
-                find_threshold(nlp_dir, docs_dir / "docs", pipe_name="_")
-            assert error.value.code == 1
-
-        # Using a pipe with no textcat components should fail.
-        nlp = English()
-        with make_tempdir() as nlp_dir:
-            nlp.to_disk(nlp_dir)
-            with pytest.raises(SystemExit) as error:
-                find_threshold(nlp_dir, docs_dir / "docs")
-            assert error.value.code == 1
-
-        # Specifying scores not in range 0 <= x <= 1 should fail.
-        nlp, _ = init_nlp()
-        DocBin(
-            docs=[
-                Example.from_dict(nlp.make_doc(t[0]), t[1]).reference
-                for t in [
-                    (
-                        "I'm angry and confused",
-                        {"cats": {"ANGRY": 1.0, "CONFUSED": 2.0, "HAPPY": 0.0}},
-                    ),
-                    (
-                        "I'm confused but happy",
-                        {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}},
-                    ),
-                ]
-            ]
-        ).to_disk(docs_dir / "docs")
-        with make_tempdir() as nlp_dir:
-            nlp.to_disk(nlp_dir)
-            with pytest.raises(SystemExit) as error:
-                find_threshold(nlp_dir, docs_dir / "docs")
+                find_threshold(
+                    model=nlp_dir,
+                    data_path=docs_dir / "docs.spacy",
+                    pipe_name="_",
+                    threshold_key="threshold",
+                    scores_key="cats_macro_f",
+                    silent=True,
+                )
             assert error.value.code == 1
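A note on the test expectation `numpy.linspace(0, 1, 10)[1]`: with the default `n_trials = 10` from `_DEFAULTS`, the candidate thresholds form an evenly spaced grid between 0 and 1, so the second grid point is roughly 0.111. A quick sketch of that grid, independent of spaCy:

import numpy

n_trials = 10  # default from _DEFAULTS in find_threshold.py
thresholds = numpy.linspace(0, 1, n_trials)
print(thresholds)      # [0.0, 0.111..., 0.222..., ..., 1.0]
print(thresholds[1])   # ~0.1111, the value the tests compare against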