mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 05:37:03 +03:00
107 lines
4.4 KiB
Python
107 lines
4.4 KiB
Python
from typing import Optional, Dict, Any, Union, List
|
|
from pathlib import Path
|
|
from wasabi import msg, table
|
|
from thinc.api import Config
|
|
from thinc.config import VARIABLE_RE
|
|
import typer
|
|
|
|
from ._util import Arg, Opt, show_validation_error, parse_config_overrides
|
|
from ._util import import_code, debug_cli
|
|
from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
|
|
from ..util import registry
|
|
from .. import util
|
|
|
|
|
|
@debug_cli.command(
|
|
"config",
|
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
|
)
|
|
def debug_config_cli(
|
|
# fmt: off
|
|
ctx: typer.Context, # This is only used to read additional arguments
|
|
config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
|
|
code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
|
show_funcs: bool = Opt(False, "--show-functions", "-F", help="Show an overview of all registered functions used in the config and where they come from (modules, files etc.)"),
|
|
show_vars: bool = Opt(False, "--show-variables", "-V", help="Show an overview of all variables referenced in the config and their values. This will also reflect variables overwritten on the CLI.")
|
|
# fmt: on
|
|
):
|
|
"""Debug a config.cfg file and show validation errors. The command will
|
|
create all objects in the tree and validate them. Note that some config
|
|
validation errors are blocking and will prevent the rest of the config from
|
|
being resolved. This means that you may not see all validation errors at
|
|
once and some issues are only shown once previous errors have been fixed.
|
|
Similar as with the 'train' command, you can override settings from the config
|
|
as command line options. For instance, --training.batch_size 128 overrides
|
|
the value of "batch_size" in the block "[training]".
|
|
|
|
DOCS: https://spacy.io/api/cli#debug-config
|
|
"""
|
|
overrides = parse_config_overrides(ctx.args)
|
|
import_code(code_path)
|
|
debug_config(
|
|
config_path, overrides=overrides, show_funcs=show_funcs, show_vars=show_vars
|
|
)
|
|
|
|
|
|
def debug_config(
|
|
config_path: Path,
|
|
*,
|
|
overrides: Dict[str, Any] = {},
|
|
show_funcs: bool = False,
|
|
show_vars: bool = False,
|
|
):
|
|
msg.divider("Config validation")
|
|
with show_validation_error(config_path):
|
|
config = util.load_config(config_path, overrides=overrides)
|
|
nlp = util.load_model_from_config(config)
|
|
config = nlp.config.interpolate()
|
|
msg.divider("Config validation for [initialize]")
|
|
with show_validation_error(config_path):
|
|
T = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
|
|
msg.divider("Config validation for [training]")
|
|
with show_validation_error(config_path):
|
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
|
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
|
util.resolve_dot_names(config, dot_names)
|
|
msg.good("Config is valid")
|
|
if show_vars:
|
|
variables = get_variables(config)
|
|
msg.divider(f"Variables ({len(variables)})")
|
|
head = ("Variable", "Value")
|
|
msg.table(variables, header=head, divider=True, widths=(41, 34), spacing=2)
|
|
if show_funcs:
|
|
funcs = get_registered_funcs(config)
|
|
msg.divider(f"Registered functions ({len(funcs)})")
|
|
for func in funcs:
|
|
func_data = {
|
|
"Registry": f"@{func['registry']}",
|
|
"Name": func["name"],
|
|
"Module": func["module"],
|
|
"File": f"{func['file']} (line {func['line_no']})",
|
|
}
|
|
msg.info(f"[{func['path']}]")
|
|
print(table(func_data).strip())
|
|
|
|
|
|
def get_registered_funcs(config: Config) -> List[Dict[str, Optional[Union[str, int]]]]:
|
|
result = []
|
|
for key, value in util.walk_dict(config):
|
|
if not key[-1].startswith("@"):
|
|
continue
|
|
# We have a reference to a registered function
|
|
reg_name = key[-1][1:]
|
|
registry = getattr(util.registry, reg_name)
|
|
path = ".".join(key[:-1])
|
|
info = registry.find(value)
|
|
result.append({"name": value, "registry": reg_name, "path": path, **info})
|
|
return result
|
|
|
|
|
|
def get_variables(config: Config) -> Dict[str, Any]:
|
|
result = {}
|
|
for variable in sorted(set(VARIABLE_RE.findall(config.to_str()))):
|
|
path = variable[2:-1].replace(":", ".")
|
|
value = util.dot_to_object(config, path)
|
|
result[variable] = repr(value)
|
|
return result
|