mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	* Use isort with Black profile * isort all the things * Fix import cycles as a result of import sorting * Add DOCBIN_ALL_ATTRS type definition * Add isort to requirements * Remove isort from build dependencies check * Typo
		
			
				
	
	
		
			114 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from pathlib import Path
 | |
| from typing import Any, Dict, List, Optional, Union
 | |
| 
 | |
| import typer
 | |
| from thinc.api import Config
 | |
| from thinc.config import VARIABLE_RE
 | |
| from wasabi import msg, table
 | |
| 
 | |
| from .. import util
 | |
| from ..schemas import ConfigSchemaInit, ConfigSchemaTraining
 | |
| from ..util import registry
 | |
| from ._util import (
 | |
|     Arg,
 | |
|     Opt,
 | |
|     debug_cli,
 | |
|     import_code,
 | |
|     parse_config_overrides,
 | |
|     show_validation_error,
 | |
| )
 | |
| 
 | |
| 
 | |
| @debug_cli.command(
 | |
|     "config",
 | |
|     context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
 | |
| )
 | |
| def debug_config_cli(
 | |
|     # fmt: off
 | |
|     ctx: typer.Context,  # This is only used to read additional arguments
 | |
|     config_path: Path = Arg(..., help="Path to config file", exists=True, allow_dash=True),
 | |
|     code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
 | |
|     show_funcs: bool = Opt(False, "--show-functions", "-F", help="Show an overview of all registered functions used in the config and where they come from (modules, files etc.)"),
 | |
|     show_vars: bool = Opt(False, "--show-variables", "-V", help="Show an overview of all variables referenced in the config and their values. This will also reflect variables overwritten on the CLI.")
 | |
|     # fmt: on
 | |
| ):
 | |
|     """Debug a config file and show validation errors. The command will
 | |
|     create all objects in the tree and validate them. Note that some config
 | |
|     validation errors are blocking and will prevent the rest of the config from
 | |
|     being resolved. This means that you may not see all validation errors at
 | |
|     once and some issues are only shown once previous errors have been fixed.
 | |
|     Similar as with the 'train' command, you can override settings from the config
 | |
|     as command line options. For instance, --training.batch_size 128 overrides
 | |
|     the value of "batch_size" in the block "[training]".
 | |
| 
 | |
|     DOCS: https://spacy.io/api/cli#debug-config
 | |
|     """
 | |
|     overrides = parse_config_overrides(ctx.args)
 | |
|     import_code(code_path)
 | |
|     debug_config(
 | |
|         config_path, overrides=overrides, show_funcs=show_funcs, show_vars=show_vars
 | |
|     )
 | |
| 
 | |
| 
 | |
| def debug_config(
 | |
|     config_path: Path,
 | |
|     *,
 | |
|     overrides: Dict[str, Any] = {},
 | |
|     show_funcs: bool = False,
 | |
|     show_vars: bool = False,
 | |
| ):
 | |
|     msg.divider("Config validation")
 | |
|     with show_validation_error(config_path):
 | |
|         config = util.load_config(config_path, overrides=overrides)
 | |
|         nlp = util.load_model_from_config(config)
 | |
|         config = nlp.config.interpolate()
 | |
|     msg.divider("Config validation for [initialize]")
 | |
|     with show_validation_error(config_path):
 | |
|         T = registry.resolve(config["initialize"], schema=ConfigSchemaInit)
 | |
|     msg.divider("Config validation for [training]")
 | |
|     with show_validation_error(config_path):
 | |
|         T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
 | |
|         dot_names = [T["train_corpus"], T["dev_corpus"]]
 | |
|         util.resolve_dot_names(config, dot_names)
 | |
|     msg.good("Config is valid")
 | |
|     if show_vars:
 | |
|         variables = get_variables(config)
 | |
|         msg.divider(f"Variables ({len(variables)})")
 | |
|         head = ("Variable", "Value")
 | |
|         msg.table(variables, header=head, divider=True, widths=(41, 34), spacing=2)
 | |
|     if show_funcs:
 | |
|         funcs = get_registered_funcs(config)
 | |
|         msg.divider(f"Registered functions ({len(funcs)})")
 | |
|         for func in funcs:
 | |
|             func_data = {
 | |
|                 "Registry": f"@{func['registry']}",
 | |
|                 "Name": func["name"],
 | |
|                 "Module": func["module"],
 | |
|                 "File": f"{func['file']} (line {func['line_no']})",
 | |
|             }
 | |
|             msg.info(f"[{func['path']}]")
 | |
|             print(table(func_data).strip())
 | |
| 
 | |
| 
 | |
| def get_registered_funcs(config: Config) -> List[Dict[str, Optional[Union[str, int]]]]:
 | |
|     result = []
 | |
|     for key, value in util.walk_dict(config):
 | |
|         if not key[-1].startswith("@"):
 | |
|             continue
 | |
|         # We have a reference to a registered function
 | |
|         reg_name = key[-1][1:]
 | |
|         registry = getattr(util.registry, reg_name)
 | |
|         path = ".".join(key[:-1])
 | |
|         info = registry.find(value)
 | |
|         result.append({"name": value, "registry": reg_name, "path": path, **info})
 | |
|     return result
 | |
| 
 | |
| 
 | |
| def get_variables(config: Config) -> Dict[str, Any]:
 | |
|     result = {}
 | |
|     for variable in sorted(set(VARIABLE_RE.findall(config.to_str()))):
 | |
|         path = variable[2:-1].replace(":", ".")
 | |
|         value = util.dot_to_object(config, path)
 | |
|         result[variable] = repr(value)
 | |
|     return result
 |