mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Add init fill-config
This commit is contained in:
parent
67cc39af7f
commit
fdcde9b0bf
|
@ -15,7 +15,7 @@ from .debug_model import debug_model # noqa: F401
|
||||||
from .evaluate import evaluate # noqa: F401
|
from .evaluate import evaluate # noqa: F401
|
||||||
from .convert import convert # noqa: F401
|
from .convert import convert # noqa: F401
|
||||||
from .init_model import init_model # noqa: F401
|
from .init_model import init_model # noqa: F401
|
||||||
from .init_config import init_config # noqa: F401
|
from .init_config import init_config, fill_config # noqa: F401
|
||||||
from .validate import validate # noqa: F401
|
from .validate import validate # noqa: F401
|
||||||
from .project.clone import project_clone # noqa: F401
|
from .project.clone import project_clone # noqa: F401
|
||||||
from .project.assets import project_assets # noqa: F401
|
from .project.assets import project_assets # noqa: F401
|
||||||
|
|
|
@ -179,13 +179,13 @@ def show_validation_error(
|
||||||
file_path: Optional[Union[str, Path]] = None,
|
file_path: Optional[Union[str, Path]] = None,
|
||||||
*,
|
*,
|
||||||
title: str = "Config validation error",
|
title: str = "Config validation error",
|
||||||
hint_init: bool = True,
|
hint_fill: bool = True,
|
||||||
):
|
):
|
||||||
"""Helper to show custom config validation errors on the CLI.
|
"""Helper to show custom config validation errors on the CLI.
|
||||||
|
|
||||||
file_path (str / Path): Optional file path of config file, used in hints.
|
file_path (str / Path): Optional file path of config file, used in hints.
|
||||||
title (str): Title of the custom formatted error.
|
title (str): Title of the custom formatted error.
|
||||||
hint_init (bool): Show hint about filling config.
|
hint_fill (bool): Show hint about filling config.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
@ -195,14 +195,14 @@ def show_validation_error(
|
||||||
# helper for this in Thinc
|
# helper for this in Thinc
|
||||||
err_text = str(e).replace("Config validation error", "").strip()
|
err_text = str(e).replace("Config validation error", "").strip()
|
||||||
print(err_text)
|
print(err_text)
|
||||||
if hint_init and "field required" in err_text:
|
if hint_fill and "field required" in err_text:
|
||||||
config_path = file_path if file_path is not None else "config.cfg"
|
config_path = file_path if file_path is not None else "config.cfg"
|
||||||
msg.text(
|
msg.text(
|
||||||
"If your config contains missing values, you can run the 'init "
|
"If your config contains missing values, you can run the 'init "
|
||||||
"config' command to fill in all the defaults, if possible:",
|
"fill-config' command to fill in all the defaults, if possible:",
|
||||||
spaced=True,
|
spaced=True,
|
||||||
)
|
)
|
||||||
print(f"{COMMAND} init config {config_path} --base {config_path}\n")
|
print(f"{COMMAND} init fill-config {config_path} --base {config_path}\n")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Tuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import Printer
|
from wasabi import Printer, diff_strings
|
||||||
|
from thinc.api import Config
|
||||||
import srsly
|
import srsly
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
@ -21,7 +22,6 @@ class Optimizations(str, Enum):
|
||||||
def init_config_cli(
|
def init_config_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
|
output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
|
||||||
# TODO: base_path: Optional[Path] = Opt(None, "--base", "-b", help="Optional base config to fill", exists=True, dir_okay=False),
|
|
||||||
lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
|
lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
|
||||||
pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"),
|
pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include in the model (without 'tok2vec' or 'transformer')"),
|
||||||
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
|
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
|
||||||
|
@ -40,6 +40,46 @@ def init_config_cli(
|
||||||
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
|
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
|
||||||
|
|
||||||
|
|
||||||
|
@init_cli.command("fill-config")
def init_fill_config_cli(
    # fmt: off
    base_path: Path = Arg(..., help="Base config to fill", exists=True, dir_okay=False),
    output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
    diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes")
    # fmt: on
):
    """
    Fill partial config.cfg with default values. Will add all missing settings
    from the default config and will create all objects, check the registered
    functions for their default values and update the base config. This command
    can be used with a config generated via the training quickstart widget:
    https://nightly.spacy.io/usage/training#quickstart
    """
    # NOTE: the CLI takes the base config first and the output second, while
    # fill_config() takes them the other way round; keywords keep the mapping
    # explicit.
    fill_config(output_file=output_file, base_path=base_path, diff=diff)
|
||||||
|
|
||||||
|
|
||||||
|
def fill_config(
    output_file: Path, base_path: Path, *, diff: bool = False
) -> Tuple[Config, Config]:
    """Load a base config, auto-fill it and save the filled result.

    output_file (Path): Destination of the filled config, or "-" for stdout.
    base_path (Path): Path of the base config to load and fill.
    diff (bool): Print a visual diff between the base config and the filled
        config. Skipped when the config is written to stdout.
    RETURNS (Tuple[Config, Config]): The loaded base config and the filled
        config.
    """
    is_stdout = str(output_file) == "-"
    # The printer is created with no_print when the config itself goes to
    # stdout, so status messages are not mixed into the config output.
    msg = Printer(no_print=is_stdout)
    with show_validation_error(hint_fill=False):
        with msg.loading("Auto-filling config..."):
            config = util.load_config(base_path)
            try:
                nlp, _ = util.load_model_from_config(config, auto_fill=True)
            except ValueError as e:
                msg.fail(str(e), exits=1)
    msg.good("Auto-filled config with all values")
    filled = nlp.config
    if diff and not is_stdout:
        msg.divider("START CONFIG DIFF")
        print("")
        print(diff_strings(config.to_str(), filled.to_str()))
        msg.divider("END CONFIG DIFF")
        print("")
    save_config(filled, output_file, is_stdout=is_stdout)
    # The signature promises a (base, filled) pair; previously the function
    # fell off the end and returned None.
    return config, filled
|
||||||
|
|
||||||
|
|
||||||
def init_config(
|
def init_config(
|
||||||
output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool
|
output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -77,7 +117,7 @@ def init_config(
|
||||||
msg.good("Generated template specific for your use case:")
|
msg.good("Generated template specific for your use case:")
|
||||||
for label, value in use_case.items():
|
for label, value in use_case.items():
|
||||||
msg.text(f"- {label}: {value}")
|
msg.text(f"- {label}: {value}")
|
||||||
with show_validation_error(hint_init=False):
|
with show_validation_error(hint_fill=False):
|
||||||
with msg.loading("Auto-filling config..."):
|
with msg.loading("Auto-filling config..."):
|
||||||
config = util.load_config_from_str(base_template)
|
config = util.load_config_from_str(base_template)
|
||||||
try:
|
try:
|
||||||
|
@ -85,17 +125,22 @@ def init_config(
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
msg.fail(str(e), exits=1)
|
msg.fail(str(e), exits=1)
|
||||||
msg.good("Auto-filled config with all values")
|
msg.good("Auto-filled config with all values")
|
||||||
|
save_config(nlp.config, output_file, is_stdout=is_stdout)
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(config: Config, output_file: Path, is_stdout: bool = False):
|
||||||
|
msg = Printer(no_print=is_stdout)
|
||||||
if is_stdout:
|
if is_stdout:
|
||||||
print(nlp.config.to_str())
|
print(config.to_str())
|
||||||
else:
|
else:
|
||||||
nlp.config.to_disk(output_file, interpolate=False)
|
config.to_disk(output_file, interpolate=False)
|
||||||
msg.good("Saved config", output_file)
|
msg.good("Saved config", output_file)
|
||||||
msg.text("You can now add your data and train your model:")
|
msg.text("You can now add your data and train your model:")
|
||||||
variables = ["--paths.train ./train.spacy", "--paths.dev ./dev.spacy"]
|
variables = ["--paths.train ./train.spacy", "--paths.dev ./dev.spacy"]
|
||||||
print(f"{COMMAND} train {output_file.parts[-1]} {' '.join(variables)}")
|
print(f"{COMMAND} train {output_file.parts[-1]} {' '.join(variables)}")
|
||||||
|
|
||||||
|
|
||||||
def require_spacy_transformers(msg):
|
def require_spacy_transformers(msg: Printer):
|
||||||
try:
|
try:
|
||||||
import spacy_transformers # noqa: F401
|
import spacy_transformers # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user