diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index bc47ffdef..2b21e2f2b 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -15,7 +15,7 @@ from .debug_model import debug_model # noqa: F401 from .evaluate import evaluate # noqa: F401 from .convert import convert # noqa: F401 from .init_model import init_model # noqa: F401 -from .init_config import init_config # noqa: F401 +from .init_config import init_config, fill_config # noqa: F401 from .validate import validate # noqa: F401 from .project.clone import project_clone # noqa: F401 from .project.assets import project_assets # noqa: F401 diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 93ec9f31e..5613fa317 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -179,13 +179,13 @@ def show_validation_error( file_path: Optional[Union[str, Path]] = None, *, title: str = "Config validation error", - hint_init: bool = True, + hint_fill: bool = True, ): """Helper to show custom config validation errors on the CLI. file_path (str / Path): Optional file path of config file, used in hints. title (str): Title of the custom formatted error. - hint_init (bool): Show hint about filling config. + hint_fill (bool): Show hint about filling config. """ try: yield @@ -195,14 +195,14 @@ def show_validation_error( # helper for this in Thinc err_text = str(e).replace("Config validation error", "").strip() print(err_text) - if hint_init and "field required" in err_text: + if hint_fill and "field required" in err_text: config_path = file_path if file_path is not None else "config.cfg" msg.text( "If your config contains missing values, you can run the 'init " - "config' command to fill in all the defaults, if possible:", + "fill-config' command to fill in all the defaults, if possible:", spaced=True, ) - print(f"{COMMAND} init config {config_path} --base {config_path}\n") + print(f"{COMMAND} init fill-config {config_path} --base {config_path}\n") sys.exit(1) diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py index cc4c980be..e4c068aa7 100644 --- a/spacy/cli/init_config.py +++ b/spacy/cli/init_config.py @@ -1,7 +1,8 @@ -from typing import Optional, List +from typing import Optional, List, Tuple from enum import Enum from pathlib import Path -from wasabi import Printer +from wasabi import Printer, diff_strings +from thinc.api import Config import srsly import re @@ -21,7 +22,6 @@ class Optimizations(str, Enum): def init_config_cli( # fmt: off output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True), - # TODO: base_path: Optional[Path] = Opt(None, "--base", "-b", help="Optional base config to fill", exists=True, dir_okay=False), lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"), pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"), optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."), @@ -40,6 +40,46 @@ def init_config_cli( init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu) +@init_cli.command("fill-config") +def init_fill_config_cli( + # fmt: off + base_path: Path = Arg(..., help="Base config to fill", exists=True, dir_okay=False), + output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True), + diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes") + # fmt: on +): + """ + Fill partial config.cfg with default values. Will add all missing settings + from the default config and will create all objects, check the registered + functions for their default values and update the base config. This command + can be used with a config generated via the training quickstart widget: + https://nightly.spacy.io/usage/training#quickstart + """ + fill_config(output_file, base_path, diff=diff) + + +def fill_config( + output_file: Path, base_path: Path, *, diff: bool = False +) -> Tuple[Config, Config]: + is_stdout = str(output_file) == "-" + msg = Printer(no_print=is_stdout) + with show_validation_error(hint_fill=False): + with msg.loading("Auto-filling config..."): + config = util.load_config(base_path) + try: + nlp, _ = util.load_model_from_config(config, auto_fill=True) + except ValueError as e: + msg.fail(str(e), exits=1) + msg.good("Auto-filled config with all values") + if diff and not is_stdout: + msg.divider("START CONFIG DIFF") + print("") + print(diff_strings(config.to_str(), nlp.config.to_str())) + msg.divider("END CONFIG DIFF") + print("") + save_config(nlp.config, output_file, is_stdout=is_stdout) + + def init_config( output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool ) -> None: @@ -77,7 +117,7 @@ def init_config( msg.good("Generated template specific for your use case:") for label, value in use_case.items(): msg.text(f"- {label}: {value}") - with show_validation_error(hint_init=False): + with show_validation_error(hint_fill=False): with msg.loading("Auto-filling config..."): config = util.load_config_from_str(base_template) try: @@ -85,17 +125,22 @@ def init_config( except ValueError as e: msg.fail(str(e), exits=1) msg.good("Auto-filled config with all values") + save_config(nlp.config, output_file, is_stdout=is_stdout) + + +def save_config(config: Config, output_file: Path, is_stdout: bool = False): + msg = Printer(no_print=is_stdout) if is_stdout: - print(nlp.config.to_str()) + print(config.to_str()) else: - nlp.config.to_disk(output_file, interpolate=False) + config.to_disk(output_file, interpolate=False) msg.good("Saved config", output_file) msg.text("You can now add your data and train your model:") variables = ["--paths.train ./train.spacy", "--paths.dev ./dev.spacy"] print(f"{COMMAND} train {output_file.parts[-1]} {' '.join(variables)}") -def require_spacy_transformers(msg): +def require_spacy_transformers(msg: Printer): try: import spacy_transformers # noqa: F401 except ImportError: