mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-27 20:33:42 +03:00
Add distill subcommand (#13431)
* Add distill subcommand This subcommand distills a student model from a teacher model. * Fixes from Sofie Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Type and doc fixes * Wording * distill: document missing `-o` * Wording * Small fix --------- Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
304b9331e6
commit
fbc14aea45
|
@ -12,6 +12,7 @@ from .debug_config import debug_config # noqa: F401
|
||||||
from .debug_data import debug_data # noqa: F401
|
from .debug_data import debug_data # noqa: F401
|
||||||
from .debug_diff import debug_diff # noqa: F401
|
from .debug_diff import debug_diff # noqa: F401
|
||||||
from .debug_model import debug_model # noqa: F401
|
from .debug_model import debug_model # noqa: F401
|
||||||
|
from .distill import distill # noqa: F401
|
||||||
from .download import download # noqa: F401
|
from .download import download # noqa: F401
|
||||||
from .evaluate import evaluate # noqa: F401
|
from .evaluate import evaluate # noqa: F401
|
||||||
from .find_function import find_function # noqa: F401
|
from .find_function import find_function # noqa: F401
|
||||||
|
|
98
spacy/cli/distill.py
Normal file
98
spacy/cli/distill.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from wasabi import msg
|
||||||
|
|
||||||
|
from .. import util
|
||||||
|
from ..pipeline.trainable_pipe import TrainablePipe
|
||||||
|
from ..schemas import ConfigSchemaDistill
|
||||||
|
from ..training.initialize import init_nlp_student
|
||||||
|
from ..training.loop import distill as distill_nlp
|
||||||
|
from ._util import (
|
||||||
|
Arg,
|
||||||
|
Opt,
|
||||||
|
app,
|
||||||
|
import_code_paths,
|
||||||
|
parse_config_overrides,
|
||||||
|
setup_gpu,
|
||||||
|
show_validation_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command(
|
||||||
|
"distill",
|
||||||
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||||
|
)
|
||||||
|
def distill_cli(
|
||||||
|
# fmt: off
|
||||||
|
ctx: typer.Context, # This is only used to read additional arguments
|
||||||
|
teacher_model: str = Arg(..., help="Teacher model name or path"),
|
||||||
|
student_config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
|
||||||
|
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store trained pipeline in"),
|
||||||
|
code_path: str = Opt("", "--code", "-c", help="Comma-separated paths to Python files with additional code (registered functions) to be imported"),
|
||||||
|
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||||
|
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
|
||||||
|
# fmt: on
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Distill a spaCy pipeline from a teacher model.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/cli#distill
|
||||||
|
"""
|
||||||
|
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||||
|
overrides = parse_config_overrides(ctx.args)
|
||||||
|
import_code_paths(code_path)
|
||||||
|
distill(
|
||||||
|
teacher_model,
|
||||||
|
student_config_path,
|
||||||
|
output_path,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
overrides=overrides,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def distill(
|
||||||
|
teacher_model: Union[str, Path],
|
||||||
|
student_config_path: Union[str, Path],
|
||||||
|
output_path: Optional[Union[str, Path]] = None,
|
||||||
|
*,
|
||||||
|
use_gpu: int = -1,
|
||||||
|
overrides: Dict[str, Any] = util.SimpleFrozenDict(),
|
||||||
|
):
|
||||||
|
student_config_path = util.ensure_path(student_config_path)
|
||||||
|
output_path = util.ensure_path(output_path)
|
||||||
|
# Make sure all files and paths exist if they are needed
|
||||||
|
if not student_config_path or (
|
||||||
|
str(student_config_path) != "-" and not student_config_path.exists()
|
||||||
|
):
|
||||||
|
msg.fail("Student config file not found", student_config_path, exits=1)
|
||||||
|
if not output_path:
|
||||||
|
msg.info("No output directory provided")
|
||||||
|
else:
|
||||||
|
if not output_path.exists():
|
||||||
|
output_path.mkdir(parents=True)
|
||||||
|
msg.good(f"Created output directory: {output_path}")
|
||||||
|
msg.info(f"Saving to output directory: {output_path}")
|
||||||
|
setup_gpu(use_gpu)
|
||||||
|
teacher = util.load_model(teacher_model)
|
||||||
|
with show_validation_error(student_config_path):
|
||||||
|
config = util.load_config(
|
||||||
|
student_config_path, overrides=overrides, interpolate=False
|
||||||
|
)
|
||||||
|
msg.divider("Initializing student pipeline")
|
||||||
|
with show_validation_error(student_config_path, hint_fill=False):
|
||||||
|
student = init_nlp_student(config, teacher, use_gpu=use_gpu)
|
||||||
|
|
||||||
|
msg.good("Initialized student pipeline")
|
||||||
|
msg.divider("Distilling student pipeline from teacher")
|
||||||
|
distill_nlp(
|
||||||
|
teacher,
|
||||||
|
student,
|
||||||
|
output_path,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
stdout=sys.stdout,
|
||||||
|
stderr=sys.stderr,
|
||||||
|
)
|
|
@ -52,6 +52,7 @@ def test_convert_auto_conflict():
|
||||||
NOOP_CONFIG = """
|
NOOP_CONFIG = """
|
||||||
[paths]
|
[paths]
|
||||||
train = null
|
train = null
|
||||||
|
distill = null
|
||||||
dev = null
|
dev = null
|
||||||
vectors = null
|
vectors = null
|
||||||
init_tok2vec = null
|
init_tok2vec = null
|
||||||
|
@ -96,6 +97,14 @@ max_length = 0
|
||||||
limit = 0
|
limit = 0
|
||||||
augmenter = null
|
augmenter = null
|
||||||
|
|
||||||
|
[corpora.distill]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.distill}
|
||||||
|
gold_preproc = false
|
||||||
|
max_length = 0
|
||||||
|
limit = 0
|
||||||
|
augmenter = null
|
||||||
|
|
||||||
[training]
|
[training]
|
||||||
seed = ${system.seed}
|
seed = ${system.seed}
|
||||||
gpu_allocator = ${system.gpu_allocator}
|
gpu_allocator = ${system.gpu_allocator}
|
||||||
|
@ -143,6 +152,37 @@ learn_rate = 0.001
|
||||||
|
|
||||||
[training.score_weights]
|
[training.score_weights]
|
||||||
|
|
||||||
|
[distillation]
|
||||||
|
student_to_teacher = []
|
||||||
|
dropout = 0.1
|
||||||
|
max_epochs = 0
|
||||||
|
max_steps = 100
|
||||||
|
corpus = "corpora.distill"
|
||||||
|
|
||||||
|
[distillation.batcher]
|
||||||
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
discard_oversize = false
|
||||||
|
tolerance = 0.2
|
||||||
|
get_length = null
|
||||||
|
|
||||||
|
[distillation.batcher.size]
|
||||||
|
@schedules = "compounding.v1"
|
||||||
|
start = 100
|
||||||
|
stop = 1000
|
||||||
|
compound = 1.001
|
||||||
|
t = 0.0
|
||||||
|
|
||||||
|
[distillation.optimizer]
|
||||||
|
@optimizers = "Adam.v1"
|
||||||
|
beta1 = 0.9
|
||||||
|
beta2 = 0.999
|
||||||
|
L2_is_weight_decay = true
|
||||||
|
L2 = 0.01
|
||||||
|
grad_clip = 1.0
|
||||||
|
use_averages = false
|
||||||
|
eps = 0.00000001
|
||||||
|
learn_rate = 0.001
|
||||||
|
|
||||||
[pretraining]
|
[pretraining]
|
||||||
|
|
||||||
[initialize]
|
[initialize]
|
||||||
|
@ -172,6 +212,8 @@ def data_paths():
|
||||||
db.to_disk(fpath)
|
db.to_disk(fpath)
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
|
"--paths.distill",
|
||||||
|
str(fpath),
|
||||||
"--paths.train",
|
"--paths.train",
|
||||||
str(fpath),
|
str(fpath),
|
||||||
"--paths.dev",
|
"--paths.dev",
|
||||||
|
@ -211,21 +253,33 @@ def noop_config():
|
||||||
yield cfg
|
yield cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def noop_model():
|
||||||
|
with make_tempdir() as temp_d:
|
||||||
|
nlp = spacy.blank("en")
|
||||||
|
path = temp_d / "noop-model"
|
||||||
|
nlp.to_disk(path)
|
||||||
|
yield path
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"cmd",
|
"cmd",
|
||||||
["debug config", "debug data", "train", "assemble"],
|
["debug config", "debug data", "train", "assemble", "distill"],
|
||||||
)
|
)
|
||||||
def test_multi_code(cmd, code_paths, data_paths, noop_config):
|
def test_multi_code(cmd, code_paths, data_paths, noop_config, noop_model):
|
||||||
# check that it fails without the code arg
|
# check that it fails without the code arg
|
||||||
cmd = cmd.split()
|
cmd = cmd.split()
|
||||||
output = ["."] if cmd[0] == "assemble" else []
|
output = ["."] if cmd[0] == "assemble" else []
|
||||||
|
model = [str(noop_model)] if cmd[0] == "distill" else []
|
||||||
cmd = [sys.executable, "-m", "spacy"] + cmd
|
cmd = [sys.executable, "-m", "spacy"] + cmd
|
||||||
result = subprocess.run([*cmd, str(noop_config), *output, *data_paths])
|
result = subprocess.run([*cmd, *model, str(noop_config), *output, *data_paths])
|
||||||
assert result.returncode == 1
|
assert result.returncode == 1
|
||||||
|
|
||||||
# check that it succeeds with the code arg
|
# check that it succeeds with the code arg
|
||||||
result = subprocess.run([*cmd, str(noop_config), *output, *data_paths, *code_paths])
|
result = subprocess.run(
|
||||||
|
[*cmd, *model, str(noop_config), *output, *data_paths, *code_paths]
|
||||||
|
)
|
||||||
assert result.returncode == 0
|
assert result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ menu:
|
||||||
- ['assemble', 'assemble']
|
- ['assemble', 'assemble']
|
||||||
- ['package', 'package']
|
- ['package', 'package']
|
||||||
- ['project', 'project']
|
- ['project', 'project']
|
||||||
|
- ['distill', 'distill']
|
||||||
- ['huggingface-hub', 'huggingface-hub']
|
- ['huggingface-hub', 'huggingface-hub']
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@ -1296,8 +1297,10 @@ input formats are:
|
||||||
|
|
||||||
When a directory is provided it is traversed recursively to collect all files.
|
When a directory is provided it is traversed recursively to collect all files.
|
||||||
|
|
||||||
When loading a .spacy file, any potential annotations stored on the `Doc` that are not overwritten by the pipeline will be preserved.
|
When loading a .spacy file, any potential annotations stored on the `Doc` that
|
||||||
If you want to evaluate the pipeline on raw text only, make sure that the .spacy file does not contain any annotations.
|
are not overwritten by the pipeline will be preserved. If you want to evaluate
|
||||||
|
the pipeline on raw text only, make sure that the .spacy file does not contain
|
||||||
|
any annotations.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key] [--force-overwrite] [--gpu-id] [--batch-size] [--n-process]
|
$ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key] [--force-overwrite] [--gpu-id] [--batch-size] [--n-process]
|
||||||
|
@ -1699,6 +1702,45 @@ $ python -m spacy project dvc [project_dir] [workflow] [--force] [--verbose] [--
|
||||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
| **CREATES** | A `dvc.yaml` file in the project directory, based on the steps defined in the given workflow. |
|
| **CREATES** | A `dvc.yaml` file in the project directory, based on the steps defined in the given workflow. |
|
||||||
|
|
||||||
|
## distill {id="distill", tag="experimental", version="4"}
|
||||||
|
|
||||||
|
Distill a _student_ pipeline from a _teacher_ pipeline. Distillation trains the
|
||||||
|
models in the student pipeline on the activations of the teacher's models. A
|
||||||
|
typical use case for distillation is to extract a smaller, more performant model
|
||||||
|
from a larger high-accuracy model. Since distillation uses the activations of
|
||||||
|
the teacher, distillation can be performed on a corpus of raw text without (gold
|
||||||
|
standard) annotations. A development set of gold annotations _is_ needed to
|
||||||
|
evaluate the student pipeline on during distillation.
|
||||||
|
|
||||||
|
`distill` will save out the best performing pipeline across all epochs, as well
|
||||||
|
as the final pipeline. The `--code` argument can be used to provide a Python
|
||||||
|
file that's imported before the training process starts. This lets you register
|
||||||
|
[custom functions](/usage/training#custom-functions) and architectures and refer
|
||||||
|
to them in your config, all while still using spaCy's built-in `train` workflow.
|
||||||
|
If you need to manage complex multi-step training workflows, check out
|
||||||
|
[Weasel](https://github.com/explosion/weasel).
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```bash
|
||||||
|
> $ python -m spacy distill teacher-pipeline student.cfg --output ./output --paths.distill ./distill --paths.dev ./dev
|
||||||
|
> ```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ python -m spacy distill [teacher_model] [student_config_path] [--output] [--code] [--verbose] [--gpu-id] [overrides]
|
||||||
|
```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `teacher_model` | The teacher pipeline (name or path) to distill the student from. ~~Union[str, Path] (positional)~~ |
|
||||||
|
| `student_config_path` | The configuration of the student pipeline. ~~Path (positional)~~ |
|
||||||
|
| `--output`, `-o` | Directory to store the distilled pipeline in. Will be created if it doesn't exist. No pipeline will be saved when this option is absent. ~~Optional[Path] \(option)~~ |
|
||||||
|
| `--code`, `-c` | Comma-separated paths to Python files with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ |
|
||||||
|
| `--verbose`, `-V` | Show more detailed messages during distillation. ~~bool (flag)~~ |
|
||||||
|
| `--gpu-id`, `-g` | GPU ID or `-1` for CPU. Defaults to `-1`. ~~int (option)~~ |
|
||||||
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
|
| **CREATES** | The final distilled pipeline and the best distilled pipeline. |
|
||||||
|
|
||||||
## huggingface-hub {id="huggingface-hub",version="3.1"}
|
## huggingface-hub {id="huggingface-hub",version="3.1"}
|
||||||
|
|
||||||
The `spacy huggingface-cli` CLI includes commands for uploading your trained
|
The `spacy huggingface-cli` CLI includes commands for uploading your trained
|
||||||
|
|
Loading…
Reference in New Issue
Block a user