mirror of https://github.com/explosion/spaCy.git
synced 2025-08-08 06:04:57 +03:00

Finish first draft for find-threshold.

parent 0e5cd6b0c0
commit 4981700ced
spacy/cli/__init__.py
@@ -28,7 +28,6 @@ from .project.push import project_push  # noqa: F401
 from .project.pull import project_pull  # noqa: F401
 from .project.document import project_document  # noqa: F401
 from .find_threshold import find_threshold  # noqa: F401
-from .find_threshold import find_threshold_cli  # noqa: F401
 
 
 @app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
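For orientation (not part of the diff): with the re-export above removed, only the implementation function remains importable from the CLI package. A minimal sketch of the resulting behavior:

from spacy.cli import find_threshold        # still re-exported (see hunk above)
# from spacy.cli import find_threshold_cli  # removed here; this import would now fail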
spacy/cli/find_threshold.py
@@ -2,19 +2,19 @@ from pathlib import Path
 import logging
 from typing import Optional
 
-# import numpy
+import numpy
+import wasabi.tables
 
-import spacy
 from ._util import app, Arg, Opt
 from .. import util
 from ..pipeline import MultiLabel_TextCategorizer
+from ..tokens import DocBin
 
 _DEFAULTS = {
-    "aggregation": "weighted",
+    "average": "micro",
     "pipe_name": None,
     "n_trials": 10,
     "beta": 1,
-    "reverse": False,
 }
 
 
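For orientation (not part of the diff): the defaults above feed the threshold grid that the function builds later via numpy.linspace. A minimal sketch of the default grid:

import numpy

# 10 evenly spaced thresholds over [0, 1], endpoints included (step 1/9),
# matching _DEFAULTS["n_trials"] = 10.
thresholds = numpy.linspace(0, 1, 10)
print(thresholds[0], thresholds[-1], len(thresholds))  # 0.0 1.0 10

Because the grid includes 0.0 and predictions are later compared with score >= threshold, the lowest trial treats every label as predicted positive.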
@@ -26,11 +26,10 @@ def find_threshold_cli(
     # fmt: off
     model_path: Path = Arg(..., help="Path to model file", exists=True, allow_dash=True),
     doc_path: Path = Arg(..., help="Path to doc bin file", exists=True, allow_dash=True),
-    aggregation: str = Arg(_DEFAULTS["aggregation"], help="How to aggregate F-scores over labels. One of ('micro', 'macro', 'weighted')", exists=True, allow_dash=True),
+    average: str = Arg(_DEFAULTS["average"], help="How to aggregate F-scores over labels. One of ('micro', 'macro')", exists=True, allow_dash=True),
     pipe_name: Optional[str] = Opt(_DEFAULTS["pipe_name"], "--pipe_name", "-p", help="Name of pipe to examine thresholds for"),
     n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
     beta: float = Opt(_DEFAULTS["beta"], "--beta", help="Beta for F1 calculation. Ignored if different metric is used"),
-    reverse: bool = Opt(_DEFAULTS["reverse"], "--reverse", "-r", help="Minimizes metric instead of maximizing it."),
     verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
     # fmt: on
 ):
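Given this signature, the command takes two required path arguments and an optional positional average, with the remaining settings passed as options; an invocation would look roughly like python -m spacy find-threshold training/model-best corpus/dev.spacy macro --pipe_name textcat_multilabel --n_trials 20. The command name and paths are illustrative assumptions here, since the @app.command registration line is not part of this hunk.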
@@ -38,12 +37,11 @@ def find_threshold_cli(
     Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric from CLI.
     model_path (Path): Path to file with trained model.
     doc_path (Path): Path to file with DocBin with docs to use for threshold search.
-    aggregation (str): How to aggregate F-scores across labels. One of ('micro', 'macro', 'weighted').
+    average (str): How to average F-scores across labels. One of ('micro', 'macro').
     pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
         is selected. If there are multiple, an error is raised.
     n_trials (int): Number of trials to determine optimal thresholds.
     beta (float): Beta for F1 calculation. Ignored if different metric is used.
-    reverse (bool): Whether to minimize metric instead of maximizing it.
     verbose (bool): Display more information for debugging purposes.
     """
 
@@ -51,11 +49,10 @@ def find_threshold_cli(
     find_threshold(
         model_path,
         doc_path,
-        aggregation=aggregation,
+        average=average,
         pipe_name=pipe_name,
         n_trials=n_trials,
         beta=beta,
-        reverse=reverse,
     )
 
 
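The CLI wrapper simply forwards to the keyword-only function, so the programmatic equivalent is direct. A sketch with hypothetical paths and component name (any pipeline containing a MultiLabel_TextCategorizer, plus a DocBin whose docs carry gold 0/1 categories, would do):

from pathlib import Path

from spacy.cli import find_threshold

find_threshold(
    Path("training/model-best"),   # hypothetical trained pipeline
    Path("corpus/dev.spacy"),      # hypothetical DocBin with gold categories
    average="macro",
    pipe_name="textcat_multilabel",
    n_trials=20,
    beta=1.0,
)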
@@ -63,59 +60,148 @@ def find_threshold(
     model_path: Path,
     doc_path: Path,
     *,
-    aggregation: str = _DEFAULTS["aggregation"],
+    average: str = _DEFAULTS["average"],
     pipe_name: Optional[str] = _DEFAULTS["pipe_name"],
     n_trials: int = _DEFAULTS["n_trials"],
     beta: float = _DEFAULTS["beta"],
-    reverse: bool = _DEFAULTS["reverse"]
 ) -> None:
     """
     Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
     model_path (Path): Path to file with trained model.
     doc_path (Path): Path to file with DocBin with docs to use for threshold search.
-    aggregation (str): How to aggregate F-scores across labels. One of ('micro', 'macro', 'weighted').
+    average (str): How to average F-scores across labels. One of ('micro', 'macro').
     pipe_name (Optional[str]): Name of pipe to examine thresholds for. If None, pipe of type MultiLabel_TextCategorizer
         is selected. If there are multiple, an error is raised.
     n_trials (int): Number of trials to determine optimal thresholds.
     beta (float): Beta for F1 calculation. Ignored if different metric is used.
-    reverse (bool): Whether to minimize metric instead of maximizing it.
     """
 
-    nlp = spacy.load(model_path)
+    nlp = util.load_model(model_path)
     pipe: Optional[MultiLabel_TextCategorizer] = None
     selected_pipe_name: Optional[str] = pipe_name
 
+    if average not in ("micro", "macro"):
+        wasabi.msg.fail(
+            f"Expected 'micro' or 'macro' for F-score averaging method, received '{average}'.",
+            exits=1,
+        )
+
     for _pipe_name, _pipe in nlp.pipeline:
         if pipe_name and _pipe_name == pipe_name:
             if not isinstance(_pipe, MultiLabel_TextCategorizer):
-                # todo convert to error
-                assert "Specified name is not a MultiLabel_TextCategorizer."
+                wasabi.msg.fail(
+                    f"Specified component '{pipe_name}' is not of type `MultiLabel_TextCategorizer`.",
+                    exits=1,
+                )
             pipe = _pipe
             break
         elif pipe_name is None:
             if isinstance(_pipe, MultiLabel_TextCategorizer):
                 if pipe:
-                    # todo convert to error
-                    assert (
-                        "Multiple components of type MultiLabel_TextCategorizer in pipeline. Please specify "
-                        "component name."
+                    wasabi.msg.fail(
+                        "Multiple components of type `MultiLabel_TextCategorizer` exist in pipeline. Specify name of "
+                        "component to evaluate.",
+                        exits=1,
                     )
                 pipe = _pipe
                 selected_pipe_name = _pipe_name
 
-    # counts = {label: 0 for label in pipe.labels}
-    # true_positive_counts = counts.copy()
-    # false_positive_counts = counts.copy()
-    # f_scores = counts.copy()
-    # thresholds = numpy.linspace(0, 1, n_trials)
+    if pipe is None:
+        if pipe_name:
+            wasabi.msg.fail(
+                f"No component with name {pipe_name} found in pipeline.", exits=1
+            )
+        wasabi.msg.fail(
+            "No component of type `MultiLabel_TextCategorizer` found in pipeline.",
+            exits=1,
+        )
 
-    # todo iterate over docs, assert categories are 1 or 0.
-    # todo run pipe for all docs in docbin.
-    # todo iterate over thresholds. for each:
-    #   - iterate over all docs. for each:
-    #     - iterate over all labels. for each:
-    #       - mark as positive/negative based on current threshold
-    #       - update count, f_score stats
-    #   - compute f_scores for all labels
-    #   - output best threshold
-    print(selected_pipe_name, pipe.labels, pipe.predict([nlp("aaa")]))
+    print(
+        f"Searching threshold with the best {average} F-score for pipe '{selected_pipe_name}' with {n_trials} trials"
+        f" and beta = {beta}."
+    )
+
+    thresholds = numpy.linspace(0, 1, n_trials)
+    ref_pos_counts = {label: 0 for label in pipe.labels}
+    pred_pos_counts = {
+        t: {True: ref_pos_counts.copy(), False: ref_pos_counts.copy()}
+        for t in thresholds
+    }
+    f_scores_per_label = {t: ref_pos_counts.copy() for t in thresholds}
+    f_scores = {t: 0 for t in thresholds}
+
+    # Count true/false positives for provided docs.
+    doc_bin = DocBin()
+    doc_bin.from_disk(doc_path)
+    for ref_doc in doc_bin.get_docs(nlp.vocab):
+        for label, score in ref_doc.cats.items():
+            if score not in (0, 1):
+                wasabi.msg.fail(
+                    f"Expected category scores in evaluation dataset to be 0 <= x <= 1, received {score}.",
+                    exits=1,
+                )
+            ref_pos_counts[label] += ref_doc.cats[label] == 1
+
+        pred_doc = nlp(ref_doc.text)
+        # Collect count stats per threshold value and label.
+        for threshold in thresholds:
+            for label, score in pred_doc.cats.items():
+                label_value = int(score >= threshold)
+                if label_value == ref_doc.cats[label] == 1:
+                    pred_pos_counts[threshold][True][label] += 1
+                elif label_value == 1 and ref_doc.cats[label] == 0:
+                    pred_pos_counts[threshold][False][label] += 1
+
+    # Compute f_scores.
+    for threshold in thresholds:
+        for label in ref_pos_counts:
+            n_pos_preds = (
+                pred_pos_counts[threshold][True][label]
+                + pred_pos_counts[threshold][False][label]
+            )
+            precision = (
+                (pred_pos_counts[threshold][True][label] / n_pos_preds)
+                if n_pos_preds > 0
+                else 0
+            )
+            recall = pred_pos_counts[threshold][True][label] / ref_pos_counts[label]
+            f_scores_per_label[threshold][label] = (
+                (
+                    (1 + beta**2)
+                    * (precision * recall / (precision * beta**2 + recall))
+                )
+                if precision
+                else 0
+            )
+
+        # Aggregate F-scores.
+        if average == "micro":
+            f_scores[threshold] = sum(
+                [
+                    f_scores_per_label[threshold][label] * ref_pos_counts[label]
+                    for label in ref_pos_counts
+                ]
+            ) / sum(ref_pos_counts.values())
+        else:
+            f_scores[threshold] = sum(
+                [f_scores_per_label[threshold][label] for label in ref_pos_counts]
+            ) / len(ref_pos_counts)
+
+    best_threshold = max(f_scores, key=f_scores.get)
+    print(
+        f"Best threshold: {round(best_threshold, ndigits=4)} with F-score of {f_scores[best_threshold]}."
+    )
+    print(
+        wasabi.tables.table(
+            data=[
+                (threshold, label, f_score)
+                for threshold, label_f_scores in f_scores_per_label.items()
+                for label, f_score in label_f_scores.items()
+            ],
+            header=["Threshold", "Label", "F-Score"],
+        ),
+        wasabi.tables.table(
+            data=[(threshold, f_score) for threshold, f_score in f_scores.items()],
+            header=["Threshold", f"F-Score ({average})"],
+        ),
+    )
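To make the scoring step above easy to verify by hand, here is a self-contained restatement of the per-label F-beta formula and the two averaging branches for a single threshold, using made-up counts for two hypothetical labels. Note that the "micro" branch, as written in the diff, is a support-weighted mean of per-label F-scores rather than a micro-average computed from pooled counts.

beta = 1
ref_pos = {"SCIENCE": 10, "SPORTS": 2}  # gold positives per label
tp = {"SCIENCE": 8, "SPORTS": 1}        # predicted positives matching gold
fp = {"SCIENCE": 4, "SPORTS": 0}        # predicted positives contradicting gold

f_per_label = {}
for label in ref_pos:
    n_pos_preds = tp[label] + fp[label]
    precision = tp[label] / n_pos_preds if n_pos_preds > 0 else 0
    recall = tp[label] / ref_pos[label]
    f_per_label[label] = (
        (1 + beta**2) * precision * recall / (precision * beta**2 + recall)
        if precision
        else 0
    )

# "micro" (support-weighted) vs. "macro" (unweighted) aggregation:
micro = sum(f_per_label[lab] * ref_pos[lab] for lab in ref_pos) / sum(ref_pos.values())
macro = sum(f_per_label.values()) / len(f_per_label)
print(f_per_label)                       # SCIENCE ≈ 0.727, SPORTS ≈ 0.667
print(round(micro, 3), round(macro, 3))  # 0.717 0.697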