Add init fill-config

This commit is contained in:
Ines Montani 2020-08-14 16:49:26 +02:00
parent 67cc39af7f
commit fdcde9b0bf
3 changed files with 58 additions and 13 deletions

View File

@ -15,7 +15,7 @@ from .debug_model import debug_model # noqa: F401
from .evaluate import evaluate # noqa: F401
from .convert import convert # noqa: F401
from .init_model import init_model # noqa: F401
from .init_config import init_config # noqa: F401
from .init_config import init_config, fill_config # noqa: F401
from .validate import validate # noqa: F401
from .project.clone import project_clone # noqa: F401
from .project.assets import project_assets # noqa: F401

View File

@ -179,13 +179,13 @@ def show_validation_error(
file_path: Optional[Union[str, Path]] = None,
*,
title: str = "Config validation error",
hint_init: bool = True,
hint_fill: bool = True,
):
"""Helper to show custom config validation errors on the CLI.
file_path (str / Path): Optional file path of config file, used in hints.
title (str): Title of the custom formatted error.
hint_init (bool): Show hint about filling config.
hint_fill (bool): Show hint about filling config.
"""
try:
yield
@ -195,14 +195,14 @@ def show_validation_error(
# helper for this in Thinc
err_text = str(e).replace("Config validation error", "").strip()
print(err_text)
if hint_init and "field required" in err_text:
if hint_fill and "field required" in err_text:
config_path = file_path if file_path is not None else "config.cfg"
msg.text(
"If your config contains missing values, you can run the 'init "
"config' command to fill in all the defaults, if possible:",
"fill-config' command to fill in all the defaults, if possible:",
spaced=True,
)
print(f"{COMMAND} init config {config_path} --base {config_path}\n")
print(f"{COMMAND} init fill-config {config_path} --base {config_path}\n")
sys.exit(1)

View File

@ -1,7 +1,8 @@
from typing import Optional, List
from typing import Optional, List, Tuple
from enum import Enum
from pathlib import Path
from wasabi import Printer
from wasabi import Printer, diff_strings
from thinc.api import Config
import srsly
import re
@ -21,7 +22,6 @@ class Optimizations(str, Enum):
def init_config_cli(
# fmt: off
output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
# TODO: base_path: Optional[Path] = Opt(None, "--base", "-b", help="Optional base config to fill", exists=True, dir_okay=False),
lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
@ -40,6 +40,46 @@ def init_config_cli(
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
# CLI entry point registered on init_cli as "fill-config"; a thin wrapper that
# forwards its arguments to fill_config().
@init_cli.command("fill-config")
def init_fill_config_cli(
    # fmt: off
    base_path: Path = Arg(..., help="Base config to fill", exists=True, dir_okay=False),
    output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
    diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes")
    # fmt: on
):
    """
    Fill partial config.cfg with default values. Will add all missing settings
    from the default config and will create all objects, check the registered
    functions for their default values and update the base config. This command
    can be used with a config generated via the training quickstart widget:
    https://nightly.spacy.io/usage/training#quickstart
    """
    # Argument order differs on purpose: the CLI takes the base config first,
    # while fill_config() takes the output file first.
    fill_config(output_file, base_path, diff=diff)
def fill_config(
    output_file: Path, base_path: Path, *, diff: bool = False
) -> Tuple[Config, Config]:
    """Auto-fill a partial config with all missing values and save the result.

    output_file (Path): Where to save the filled config, or "-" to print it
        to stdout.
    base_path (Path): Path of the base config to load and fill.
    diff (bool): Print a visual diff between the base config and the filled
        config. Has no effect when the config is written to stdout.
    RETURNS (Tuple[Config, Config]): The base config as loaded from base_path
        and the auto-filled config.
    """
    is_stdout = str(output_file) == "-"
    # Status messages are disabled when the config itself is printed to
    # stdout, so the output only contains the config.
    msg = Printer(no_print=is_stdout)
    with show_validation_error(hint_fill=False):
        with msg.loading("Auto-filling config..."):
            config = util.load_config(base_path)
            try:
                nlp, _ = util.load_model_from_config(config, auto_fill=True)
            except ValueError as e:
                msg.fail(str(e), exits=1)
    filled = nlp.config
    msg.good("Auto-filled config with all values")
    if diff and not is_stdout:
        msg.divider("START CONFIG DIFF")
        print("")
        print(diff_strings(config.to_str(), filled.to_str()))
        msg.divider("END CONFIG DIFF")
        print("")
    save_config(filled, output_file, is_stdout=is_stdout)
    # The signature promises both configs; previously the function fell off
    # the end and returned None.
    return config, filled
def init_config(
output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool
) -> None:
@ -77,7 +117,7 @@ def init_config(
msg.good("Generated template specific for your use case:")
for label, value in use_case.items():
msg.text(f"- {label}: {value}")
with show_validation_error(hint_init=False):
with show_validation_error(hint_fill=False):
with msg.loading("Auto-filling config..."):
config = util.load_config_from_str(base_template)
try:
@ -85,17 +125,22 @@ def init_config(
except ValueError as e:
msg.fail(str(e), exits=1)
msg.good("Auto-filled config with all values")
save_config(nlp.config, output_file, is_stdout=is_stdout)
def save_config(config: Config, output_file: Path, is_stdout: bool = False) -> None:
    """Print a config to stdout or write it to disk.

    config (Config): The config to save.
    output_file (Path): Destination file. Only used if is_stdout is False.
    is_stdout (bool): Print the config to stdout instead of writing a file.
        Status messages are disabled in that case.
    """
    msg = Printer(no_print=is_stdout)
    if is_stdout:
        print(config.to_str())
        return
    # Make sure the target directory exists before writing, so saving to a
    # not-yet-created location doesn't fail.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    config.to_disk(output_file, interpolate=False)
    msg.good("Saved config", output_file)
    msg.text("You can now add your data and train your model:")
    variables = ["--paths.train ./train.spacy", "--paths.dev ./dev.spacy"]
    print(f"{COMMAND} train {output_file.parts[-1]} {' '.join(variables)}")
def require_spacy_transformers(msg):
def require_spacy_transformers(msg: Printer):
try:
import spacy_transformers # noqa: F401
except ImportError: