Add a `spacy evaluate speed` subcommand

This subcommand reports the mean per-batch throughput of a pipeline on a data set,
in words per second, together with a 95% confidence interval. For reliability, it
first performs a number of warmup passes over the data and then measures throughput
on batches of randomly shuffled documents.

To avoid adding too many top-level spaCy commands, `speed` is a subcommand of
`evaluate`, and accuracy evaluation is moved to its own `evaluate accuracy` subcommand.
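
For illustration, a possible invocation of the new subcommand (the pipeline name and the data path are placeholders; the options are the ones defined in the diff below):

    python -m spacy evaluate speed en_core_web_sm ./dev.spacy --batches 50 --warmup 3

The reported mean and its confidence interval are given in words per second (WPS).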

Daniël de Kok, 2022-11-30 11:56:11 +01:00
parent 6f9d630f7e, commit 6c43081b78
5 changed files with 189 additions and 5 deletions

@@ -15,7 +15,8 @@ from .debug_data import debug_data  # noqa: F401
 from .debug_config import debug_config  # noqa: F401
 from .debug_model import debug_model  # noqa: F401
 from .debug_diff import debug_diff  # noqa: F401
-from .evaluate import evaluate  # noqa: F401
+from .evaluate_accuracy import evaluate  # noqa: F401
+from .evaluate_speed import evaluate_cli  # noqa: F401
 from .convert import convert  # noqa: F401
 from .init_pipeline import init_pipeline_cli  # noqa: F401
 from .init_config import init_config, fill_config  # noqa: F401

@@ -46,6 +46,7 @@ DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
 commands to check and validate your config files, training and evaluation data,
 and custom model implementations.
 """
+EVALUATE_HELP = """Commands for evaluating pipelines."""
 INIT_HELP = """Commands for initializing configs and pipeline packages."""
 # Wrappers for Typer's annotations. Initially created to set defaults and to
@@ -56,10 +57,12 @@ Opt = typer.Option
 app = typer.Typer(name=NAME, help=HELP)
 project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
 debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
+evaluate_cli = typer.Typer(name="evaluate", help=EVALUATE_HELP, no_args_is_help=True)
 init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
 app.add_typer(project_cli)
 app.add_typer(debug_cli)
+app.add_typer(evaluate_cli)
 app.add_typer(init_cli)

@@ -4,18 +4,20 @@ from pathlib import Path
 import re
 import srsly
 from thinc.api import fix_random_seed
+import typer
 from ..training import Corpus
 from ..tokens import Doc
-from ._util import app, Arg, Opt, setup_gpu, import_code
+from ._util import Arg, Opt, evaluate_cli, setup_gpu, import_code
 from ..scorer import Scorer
 from .. import util
 from .. import displacy
-@app.command("evaluate")
-def evaluate_cli(
+@evaluate_cli.command("accuracy", context_settings={"allow_extra_args": True, "ignore_unknown_options": True},)
+def accuracy_cli(
     # fmt: off
+    ctx: typer.Context,
     model: str = Arg(..., help="Model name or path"),
     data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
     output: Optional[Path] = Opt(None, "--output", "-o", help="Output JSON file for metrics", dir_okay=False),

spacy/cli/evaluate_speed.py (new file, 178 lines)

@@ -0,0 +1,178 @@
from typing import Iterable, List, Optional
import random
from itertools import islice
import numpy
from pathlib import Path
import time
from tqdm import tqdm
import typer
from .. import Language, util
from ..tokens import Doc
from ..training import Corpus
from ._util import Arg, Opt, evaluate_cli, setup_gpu


@evaluate_cli.command(
    "speed",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def benchmark_cli(
    ctx: typer.Context,
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(
        ..., help="Location of binary evaluation data in .spacy format", exists=True
    ),
    batch_size: Optional[int] = Opt(
        None, "--batch-size", "-b", min=1, help="Override the pipeline batch size"
    ),
    no_shuffle: bool = Opt(False, "--no-shuffle", help="Do not shuffle benchmark data"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    n_batches: int = Opt(
        50,
        "--batches",
        help="Minimum number of batches to benchmark",
        min=30,
    ),
    warmup_epochs: int = Opt(
        3, "--warmup", "-w", min=0, help="Number of iterations over the data for warmup"
    ),
):
"""
Benchmark a pipeline. Expects a loadable spaCy pipeline and benchmark
data in the binary .spacy format.
"""
setup_gpu(use_gpu=use_gpu, silent=False)
nlp = util.load_model(model)
batch_size = batch_size if batch_size is not None else nlp.batch_size
corpus = Corpus(data_path)
docs = [eg.predicted for eg in corpus(nlp)]
print(f"Warming up for {warmup_epochs} epochs...")
warmup(nlp, docs, warmup_epochs, batch_size)
print()
print(f"Benchmarking {n_batches} batches...")
wps = benchmark(nlp, docs, n_batches, batch_size, not no_shuffle)
print()
print_outliers(wps)
print_mean_with_ci(wps)
# Lowercased, behaves as a context manager function.
class time_context:
    """Register the running time of a context."""

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.elapsed = time.perf_counter() - self.start


class Quartiles:
    """Calculate the q1, q2, q3 quartiles and the inter-quartile range (iqr)
    of a sample."""

    q1: float
    q2: float
    q3: float
    iqr: float

    def __init__(self, sample: numpy.ndarray) -> None:
        self.q1 = numpy.quantile(sample, 0.25)
        self.q2 = numpy.quantile(sample, 0.5)
        self.q3 = numpy.quantile(sample, 0.75)
        self.iqr = self.q3 - self.q1


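# Each batch is timed as it is drawn from nlp.pipe, and its token count is
# converted into a words-per-second measurement.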
def annotate(
    nlp: Language, docs: List[Doc], batch_size: Optional[int]
) -> numpy.ndarray:
    docs = nlp.pipe(tqdm(docs, unit="doc"), batch_size=batch_size)
    wps = []
    while True:
        with time_context() as elapsed:
            batch_docs = list(
                islice(docs, batch_size if batch_size else nlp.batch_size)
            )
        if len(batch_docs) == 0:
            break
        n_tokens = count_tokens(batch_docs)
        wps.append(n_tokens / elapsed.elapsed)

    return numpy.array(wps)


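# Build the benchmark batches as fresh, unannotated copies of the corpus texts,
# either sampled randomly (shuffled) or cycled through in order.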
def benchmark(
    nlp: Language,
    docs: List[Doc],
    n_batches: int,
    batch_size: int,
    shuffle: bool,
) -> numpy.ndarray:
    if shuffle:
        bench_docs = [
            nlp.make_doc(random.choice(docs).text)
            for _ in range(n_batches * batch_size)
        ]
    else:
        bench_docs = [
            nlp.make_doc(docs[i % len(docs)].text)
            for i in range(n_batches * batch_size)
        ]

    return annotate(nlp, bench_docs, batch_size)


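# Percentile bootstrap: resample the per-batch measurements with replacement to
# estimate the sampling distribution of the statistic (the mean by default).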
def bootstrap(x, statistic=numpy.mean, iterations=10000) -> numpy.ndarray:
    """Apply a statistic to repeated random samples of an array."""
    return numpy.fromiter(
        (
            statistic(numpy.random.choice(x, len(x), replace=True))
            for _ in range(iterations)
        ),
        numpy.float64,
    )


def count_tokens(docs: Iterable[Doc]) -> int:
    return sum(len(doc) for doc in docs)


def print_mean_with_ci(sample: numpy.ndarray):
    mean = numpy.mean(sample)
    bootstrap_means = bootstrap(sample)
    bootstrap_means.sort()

    # 95% confidence interval
    low = bootstrap_means[int(len(bootstrap_means) * 0.025)]
    high = bootstrap_means[int(len(bootstrap_means) * 0.975)]

    print(f"Mean: {mean:.1f} WPS (95% CI: {low-mean:.1f} +{high-mean:.1f})")


def print_outliers(sample: numpy.ndarray):
    quartiles = Quartiles(sample)

    n_outliers = numpy.sum(
        (sample < (quartiles.q1 - 1.5 * quartiles.iqr))
        | (sample > (quartiles.q3 + 1.5 * quartiles.iqr))
    )
    n_extreme_outliers = numpy.sum(
        (sample < (quartiles.q1 - 3.0 * quartiles.iqr))
        | (sample > (quartiles.q3 + 3.0 * quartiles.iqr))
    )
    print(
        f"Outliers: {(100 * n_outliers) / len(sample):.1f}%, extreme outliers: {(100 * n_extreme_outliers) / len(sample)}%"
    )


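# Run the pipeline over the benchmark data a few times before measuring, so that
# one-time startup costs do not affect the timed batches.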
def warmup(
    nlp: Language, docs: List[Doc], warmup_epochs: int, batch_size: Optional[int]
) -> numpy.ndarray:
    docs = warmup_epochs * docs
    return annotate(nlp, docs, batch_size)

@@ -8,7 +8,7 @@ from wasabi import msg
 import spacy
 from spacy import util
-from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
+from spacy.cli.evaluate_accuracy import print_prf_per_type, print_textcats_auc_per_cat
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline import TextCategorizer