Add init fill-config

This commit is contained in:
Ines Montani 2020-08-14 16:49:26 +02:00
parent 67cc39af7f
commit fdcde9b0bf
3 changed files with 58 additions and 13 deletions

View File

@ -15,7 +15,7 @@ from .debug_model import debug_model # noqa: F401
from .evaluate import evaluate # noqa: F401 from .evaluate import evaluate # noqa: F401
from .convert import convert # noqa: F401 from .convert import convert # noqa: F401
from .init_model import init_model # noqa: F401 from .init_model import init_model # noqa: F401
from .init_config import init_config # noqa: F401 from .init_config import init_config, fill_config # noqa: F401
from .validate import validate # noqa: F401 from .validate import validate # noqa: F401
from .project.clone import project_clone # noqa: F401 from .project.clone import project_clone # noqa: F401
from .project.assets import project_assets # noqa: F401 from .project.assets import project_assets # noqa: F401

View File

@ -179,13 +179,13 @@ def show_validation_error(
file_path: Optional[Union[str, Path]] = None, file_path: Optional[Union[str, Path]] = None,
*, *,
title: str = "Config validation error", title: str = "Config validation error",
hint_init: bool = True, hint_fill: bool = True,
): ):
"""Helper to show custom config validation errors on the CLI. """Helper to show custom config validation errors on the CLI.
file_path (str / Path): Optional file path of config file, used in hints. file_path (str / Path): Optional file path of config file, used in hints.
title (str): Title of the custom formatted error. title (str): Title of the custom formatted error.
hint_init (bool): Show hint about filling config. hint_fill (bool): Show hint about filling config.
""" """
try: try:
yield yield
@ -195,14 +195,14 @@ def show_validation_error(
# helper for this in Thinc # helper for this in Thinc
err_text = str(e).replace("Config validation error", "").strip() err_text = str(e).replace("Config validation error", "").strip()
print(err_text) print(err_text)
if hint_init and "field required" in err_text: if hint_fill and "field required" in err_text:
config_path = file_path if file_path is not None else "config.cfg" config_path = file_path if file_path is not None else "config.cfg"
msg.text( msg.text(
"If your config contains missing values, you can run the 'init " "If your config contains missing values, you can run the 'init "
"config' command to fill in all the defaults, if possible:", "fill-config' command to fill in all the defaults, if possible:",
spaced=True, spaced=True,
) )
print(f"{COMMAND} init config {config_path} --base {config_path}\n") print(f"{COMMAND} init fill-config {config_path} --base {config_path}\n")
sys.exit(1) sys.exit(1)

View File

@ -1,7 +1,8 @@
from typing import Optional, List from typing import Optional, List, Tuple
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from wasabi import Printer from wasabi import Printer, diff_strings
from thinc.api import Config
import srsly import srsly
import re import re
@ -21,7 +22,6 @@ class Optimizations(str, Enum):
def init_config_cli( def init_config_cli(
# fmt: off # fmt: off
output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True), output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
# TODO: base_path: Optional[Path] = Opt(None, "--base", "-b", help="Optional base config to fill", exists=True, dir_okay=False),
lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"), lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"), pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."), optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
@ -40,6 +40,46 @@ def init_config_cli(
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu) init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
@init_cli.command("fill-config")
def init_fill_config_cli(
    # fmt: off
    base_path: Path = Arg(..., help="Base config to fill", exists=True, dir_okay=False),
    output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
    diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes")
    # fmt: on
):
    """
    Fill partial config.cfg with default values. Will add all missing settings
    from the default config and will create all objects, check the registered
    functions for their default values and update the base config. This command
    can be used with a config generated via the training quickstart widget:
    https://nightly.spacy.io/usage/training#quickstart
    """
    # CLI entry point registered under the name "fill-config" on `init_cli`.
    # `Arg` / `Opt` are defined outside this view -- presumably the CLI
    # framework's positional-argument / option helpers (TODO confirm).
    # Argument order on the command line is base_path first, output_file
    # second; note that fill_config() takes them in the opposite order.
    # All work is delegated to fill_config() so it can also be called from
    # Python code without going through the CLI.
    fill_config(output_file, base_path, diff=diff)
def fill_config(
    output_file: Path, base_path: Path, *, diff: bool = False
) -> Tuple[Config, Config]:
    """Auto-fill a partial config with default values and save the result.

    output_file (Path): Where to save the filled config. The path "-" means
        the filled config is printed to stdout instead and all status
        messages are silenced.
    base_path (Path): Path of the base config to load and fill.
    diff (bool): Print a diff between the base config and the filled config.
        Ignored when writing to stdout, so the output stays a clean config.
    RETURNS (Tuple[Config, Config]): The base config as loaded from
        base_path and the auto-filled config.
    """
    is_stdout = str(output_file) == "-"
    # Keep stdout clean when the config itself is written there.
    msg = Printer(no_print=is_stdout)
    with show_validation_error(hint_fill=False):
        with msg.loading("Auto-filling config..."):
            config = util.load_config(base_path)
            try:
                nlp, _ = util.load_model_from_config(config, auto_fill=True)
            except ValueError as e:
                msg.fail(str(e), exits=1)
    filled = nlp.config
    msg.good("Auto-filled config with all values")
    if diff and not is_stdout:
        msg.divider("START CONFIG DIFF")
        print("")
        print(diff_strings(config.to_str(), filled.to_str()))
        msg.divider("END CONFIG DIFF")
        print("")
    save_config(filled, output_file, is_stdout=is_stdout)
    # The signature promises both configs; previously the function fell off
    # the end and implicitly returned None.
    return config, filled
def init_config( def init_config(
output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool
) -> None: ) -> None:
@ -77,7 +117,7 @@ def init_config(
msg.good("Generated template specific for your use case:") msg.good("Generated template specific for your use case:")
for label, value in use_case.items(): for label, value in use_case.items():
msg.text(f"- {label}: {value}") msg.text(f"- {label}: {value}")
with show_validation_error(hint_init=False): with show_validation_error(hint_fill=False):
with msg.loading("Auto-filling config..."): with msg.loading("Auto-filling config..."):
config = util.load_config_from_str(base_template) config = util.load_config_from_str(base_template)
try: try:
@ -85,17 +125,22 @@ def init_config(
except ValueError as e: except ValueError as e:
msg.fail(str(e), exits=1) msg.fail(str(e), exits=1)
msg.good("Auto-filled config with all values") msg.good("Auto-filled config with all values")
save_config(nlp.config, output_file, is_stdout=is_stdout)
def save_config(config: Config, output_file: Path, is_stdout: bool = False):
    """Write a config to stdout or to disk.

    config (Config): The config to output.
    output_file (Path): Destination file, used only when is_stdout is False.
    is_stdout (bool): Print the config to stdout instead of saving it. Status
        messages are silenced in that case.
    """
    printer = Printer(no_print=is_stdout)
    if is_stdout:
        print(config.to_str())
        return
    config.to_disk(output_file, interpolate=False)
    printer.good("Saved config", output_file)
    printer.text("You can now add your data and train your model:")
    # Example overrides shown to the user as part of the suggested command.
    overrides = " ".join(["--paths.train ./train.spacy", "--paths.dev ./dev.spacy"])
    print(f"{COMMAND} train {output_file.parts[-1]} {overrides}")
def require_spacy_transformers(msg): def require_spacy_transformers(msg: Printer):
try: try:
import spacy_transformers # noqa: F401 import spacy_transformers # noqa: F401
except ImportError: except ImportError: