mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 13:47:13 +03:00
5fb8b7037a
Validate both `[initialize]` and `[training]` in `debug data` and `nlp.initialize()` with separate config validation error blocks that indicate which block of the config is being validated.
107 lines
4.4 KiB
Python
107 lines
4.4 KiB
Python
from typing import Optional, Dict, Any, Union, List
|
|
from pathlib import Path
|
|
from wasabi import msg, table
|
|
from thinc.api import Config
|
|
from thinc.config import VARIABLE_RE
|
|
import typer
|
|
|
|
from ._util import Arg, Opt, show_validation_error, parse_config_overrides
|
|
from ._util import import_code, debug_cli
|
|
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
|
|
from ..util import registry
|
|
from .. import util
|
|
|
|
|
|
@debug_cli.command(
|
|
"config",
|
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
|
)
|
|
def debug_config_cli(
|
|
# fmt: off
|
|
ctx: typer.Context, # This is only used to read additional arguments
|
|
config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
|
|
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
|
show_funcs: bool = Opt(False, "--show-functions", "-F", help="Show an overview of all registered functions used in the config and where they come from (modules, files etc.)"),
|
|
show_vars: bool = Opt(False, "--show-variables", "-V", help="Show an overview of all variables referenced in the config and their values. This will also reflect variables overwritten on the CLI.")
|
|
# fmt: on
|
|
):
|
|
"""Debug a config.cfg file and show validation errors. The command will
|
|
create all objects in the tree and validate them. Note that some config
|
|
validation errors are blocking and will prevent the rest of the config from
|
|
being resolved. This means that you may not see all validation errors at
|
|
once and some issues are only shown once previous errors have been fixed.
|
|
Similar as with the 'train' command, you can override settings from the config
|
|
as command line options. For instance, --training.batch_size 128 overrides
|
|
the value of "batch_size" in the block "[training]".
|
|
|
|
DOCS: https://nightly.spacy.io/api/cli#debug-config
|
|
"""
|
|
overrides = parse_config_overrides(ctx.args)
|
|
import_code(code_path)
|
|
debug_config(
|
|
config_path, overrides=overrides, show_funcs=show_funcs, show_vars=show_vars
|
|
)
|
|
|
|
|
|
def debug_config(
|
|
config_path: Path,
|
|
*,
|
|
overrides: Dict[str, Any] = {},
|
|
show_funcs: bool = False,
|
|
show_vars: bool = False,
|
|
):
|
|
msg.divider("Config validation")
|
|
with show_validation_error(config_path):
|
|
config = util.load_config(config_path, overrides=overrides)
|
|
nlp = util.load_model_from_config(config)
|
|
config = nlp.config.interpolate()
|
|
msg.divider("Config validation for [initialize]")
|
|
with show_validation_error(config_path):
|
|
T = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
|
|
msg.divider("Config validation for [training]")
|
|
with show_validation_error(config_path):
|
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
|
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
|
util.resolve_dot_names(config, dot_names)
|
|
msg.good("Config is valid")
|
|
if show_vars:
|
|
variables = get_variables(config)
|
|
msg.divider(f"Variables ({len(variables)})")
|
|
head = ("Variable", "Value")
|
|
msg.table(variables, header=head, divider=True, widths=(41, 34), spacing=2)
|
|
if show_funcs:
|
|
funcs = get_registered_funcs(config)
|
|
msg.divider(f"Registered functions ({len(funcs)})")
|
|
for func in funcs:
|
|
func_data = {
|
|
"Registry": f"@{func['registry']}",
|
|
"Name": func["name"],
|
|
"Module": func["module"],
|
|
"File": f"{func['file']} (line {func['line_no']})",
|
|
}
|
|
msg.info(f"[{func['path']}]")
|
|
print(table(func_data).strip())
|
|
|
|
|
|
def get_registered_funcs(config: Config) -> List[Dict[str, Optional[Union[str, int]]]]:
|
|
result = []
|
|
for key, value in util.walk_dict(config):
|
|
if not key[-1].startswith("@"):
|
|
continue
|
|
# We have a reference to a registered function
|
|
reg_name = key[-1][1:]
|
|
registry = getattr(util.registry, reg_name)
|
|
path = ".".join(key[:-1])
|
|
info = registry.find(value)
|
|
result.append({"name": value, "registry": reg_name, "path": path, **info})
|
|
return result
|
|
|
|
|
|
def get_variables(config: Config) -> Dict[str, Any]:
|
|
result = {}
|
|
for variable in sorted(set(VARIABLE_RE.findall(config.to_str()))):
|
|
path = variable[2:-1].replace(":", ".")
|
|
value = util.dot_to_object(config, path)
|
|
result[variable] = repr(value)
|
|
return result
|