mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-28 02:46:35 +03:00
fbc14aea45
* Add distill subcommand This subcommand distills a student model from a teacher model. * Fixes from Sofie Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Type and doc fixes * Wording * distill: document missing `-o` * Wording * Small fix --------- Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional, Union
|
|
|
|
import typer
|
|
from wasabi import msg
|
|
|
|
from .. import util
|
|
from ..pipeline.trainable_pipe import TrainablePipe
|
|
from ..schemas import ConfigSchemaDistill
|
|
from ..training.initialize import init_nlp_student
|
|
from ..training.loop import distill as distill_nlp
|
|
from ._util import (
|
|
Arg,
|
|
Opt,
|
|
app,
|
|
import_code_paths,
|
|
parse_config_overrides,
|
|
setup_gpu,
|
|
show_validation_error,
|
|
)
|
|
|
|
|
|
@app.command(
    "distill",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def distill_cli(
    # fmt: off
    ctx: typer.Context,  # This is only used to read additional arguments
    teacher_model: str = Arg(..., help="Teacher model name or path"),
    student_config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
    output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store trained pipeline in"),
    code_path: str = Opt("", "--code", "-c", help="Comma-separated paths to Python files with additional code (registered functions) to be imported"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
    # fmt: on
):
    """
    Distill a spaCy pipeline from a teacher model.

    DOCS: https://spacy.io/api/cli#distill
    """
    # NOTE: the docstring above doubles as the CLI help text, so it is kept
    # short and user-facing; implementation notes live in these comments.
    # --verbose switches the shared spaCy logger to DEBUG, otherwise INFO.
    log_level = logging.DEBUG if verbose else logging.INFO
    util.logger.setLevel(log_level)
    # Extra CLI arguments (permitted by `context_settings` above) are parsed
    # into config overrides for the student config.
    config_overrides = parse_config_overrides(ctx.args)
    # Import user code before distilling so any functions it registers are
    # available when the config is used.
    import_code_paths(code_path)
    distill(
        teacher_model=teacher_model,
        student_config_path=student_config_path,
        output_path=output_path,
        use_gpu=use_gpu,
        overrides=config_overrides,
    )
|
|
|
|
|
|
def distill(
    teacher_model: Union[str, Path],
    student_config_path: Union[str, Path],
    output_path: Optional[Union[str, Path]] = None,
    *,
    use_gpu: int = -1,
    overrides: Dict[str, Any] = util.SimpleFrozenDict(),
) -> None:
    """Distill a student pipeline from a teacher pipeline.

    teacher_model (Union[str, Path]): Name or path of the teacher pipeline,
        loaded with `util.load_model`.
    student_config_path (Union[str, Path]): Path to the student config file.
        The special value "-" is accepted without an on-disk existence check
        (the CLI allows it via `allow_dash=True`).
    output_path (Optional[Union[str, Path]]): Directory for the distilled
        pipeline. It is created (including parents) if it does not exist.
        If omitted, an info message is printed and `None` is passed on to
        the distillation loop.
    use_gpu (int): GPU id, or -1 for CPU.
    overrides (Dict[str, Any]): Config overrides applied when loading the
        student config.

    Exits with status 1 (via `msg.fail(..., exits=1)`) if the student config
    cannot be found, or if `output_path` exists but is not a directory.
    """
    student_config_path = util.ensure_path(student_config_path)
    output_path = util.ensure_path(output_path)
    # Make sure all files and paths exist if they are needed. "-" is a
    # permitted config path that does not refer to a file on disk.
    if not student_config_path or (
        str(student_config_path) != "-" and not student_config_path.exists()
    ):
        msg.fail("Student config file not found", student_config_path, exits=1)
    if not output_path:
        msg.info("No output directory provided")
    else:
        if not output_path.exists():
            output_path.mkdir(parents=True)
            msg.good(f"Created output directory: {output_path}")
        elif not output_path.is_dir():
            # The output is documented as a directory. Reject an existing
            # regular file up front instead of letting the problem surface
            # only after distillation has already run (presumably when the
            # loop tries to write into it).
            msg.fail(
                "Output path exists but is not a directory", output_path, exits=1
            )
        msg.info(f"Saving to output directory: {output_path}")
    # Select the device before loading the teacher and building the student.
    setup_gpu(use_gpu)
    teacher = util.load_model(teacher_model)
    with show_validation_error(student_config_path):
        # interpolate=False: the raw config is handed to init_nlp_student,
        # which takes care of the rest of initialization.
        config = util.load_config(
            student_config_path, overrides=overrides, interpolate=False
        )
    msg.divider("Initializing student pipeline")
    with show_validation_error(student_config_path, hint_fill=False):
        student = init_nlp_student(config, teacher, use_gpu=use_gpu)

    msg.good("Initialized student pipeline")
    msg.divider("Distilling student pipeline from teacher")
    distill_nlp(
        teacher,
        student,
        output_path,
        use_gpu=use_gpu,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
|