from typing import Dict, Any, Union, List, Optional, TYPE_CHECKING import sys import shutil from pathlib import Path from wasabi import msg import srsly import hashlib import typer from typer.main import get_command from contextlib import contextmanager from thinc.config import Config, ConfigValidationError from configparser import InterpolationError from ..schemas import ProjectConfigSchema, validate from ..util import import_file, run_command, make_tempdir if TYPE_CHECKING: from pathy import Pathy # noqa: F401 PROJECT_FILE = "project.yml" PROJECT_LOCK = "project.lock" COMMAND = "python -m spacy" NAME = "spacy" HELP = """spaCy Command-line Interface DOCS: https://spacy.io/api/cli """ PROJECT_HELP = f"""Command-line interface for spaCy projects and templates. You'd typically start by cloning a project template to a local directory and fetching its assets like datasets etc. See the project's {PROJECT_FILE} for the available commands. """ DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes commands to check and validate your config files, training and evaluation data, and custom model implementations. """ INIT_HELP = """Commands for initializing configs and models.""" # Wrappers for Typer's annotations. Initially created to set defaults and to # keep the names short, but not needed at the moment. Arg = typer.Argument Opt = typer.Option app = typer.Typer(name=NAME, help=HELP) project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True) debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True) init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True) app.add_typer(project_cli) app.add_typer(debug_cli) app.add_typer(init_cli) def setup_cli() -> None: # Ensure that the help messages always display the correct prompt command = get_command(app) command(prog_name=COMMAND) def parse_config_overrides(args: List[str]) -> Dict[str, Any]: """Generate a dictionary of config overrides based on the extra arguments provided on the CLI, e.g. --training.batch_size to override "training.batch_size". Arguments without a "." are considered invalid, since the config only allows top-level sections to exist. args (List[str]): The extra arguments from the command line. RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting. """ result = {} while args: opt = args.pop(0) err = f"Invalid CLI argument '{opt}'" if opt.startswith("--"): # new argument opt = opt.replace("--", "") if "." not in opt: msg.fail(f"{err}: can't override top-level section", exits=1) if "=" in opt: # we have --opt=value opt, value = opt.split("=", 1) opt = opt.replace("-", "_") else: if not args or args[0].startswith("--"): # flag with no value value = "true" else: value = args.pop(0) # Just like we do in the config, we're calling json.loads on the # values. But since they come from the CLI, it'd be unintuitive to # explicitly mark strings with escaped quotes. So we're working # around that here by falling back to a string if parsing fails. # TODO: improve logic to handle simple types like list of strings? try: result[opt] = srsly.json_loads(value) except ValueError: result[opt] = str(value) else: msg.fail(f"{err}: override option should start with --", exits=1) return result def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]: """Load the project.yml file from a directory and validate it. Also make sure that all directories defined in the config exist. path (Path): The path to the project directory. interpolate (bool): Whether to substitute project variables. RETURNS (Dict[str, Any]): The loaded project.yml. """ config_path = path / PROJECT_FILE if not config_path.exists(): msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1) invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct." try: config = srsly.read_yaml(config_path) except ValueError as e: msg.fail(invalid_err, e, exits=1) errors = validate(ProjectConfigSchema, config) if errors: msg.fail(invalid_err) print("\n".join(errors)) sys.exit(1) validate_project_commands(config) # Make sure directories defined in config exist for subdir in config.get("directories", []): dir_path = path / subdir if not dir_path.exists(): dir_path.mkdir(parents=True) if interpolate: err = "project.yml validation error" with show_validation_error(title=err, hint_fill=False): config = substitute_project_variables(config) return config def substitute_project_variables(config: Dict[str, Any], overrides: Dict = {}): key = "vars" config.setdefault(key, {}) config[key].update(overrides) # Need to put variables in the top scope again so we can have a top-level # section "project" (otherwise, a list of commands in the top scope wouldn't) # be allowed by Thinc's config system cfg = Config({"project": config, key: config[key]}) interpolated = cfg.interpolate() return dict(interpolated["project"]) def validate_project_commands(config: Dict[str, Any]) -> None: """Check that project commands and workflows are valid, don't contain duplicates, don't clash and only refer to commands that exist. config (Dict[str, Any]): The loaded config. """ command_names = [cmd["name"] for cmd in config.get("commands", [])] workflows = config.get("workflows", {}) duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1]) if duplicates: err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}" msg.fail(err, exits=1) for workflow_name, workflow_steps in workflows.items(): if workflow_name in command_names: err = f"Can't use workflow name '{workflow_name}': name already exists as a command" msg.fail(err, exits=1) for step in workflow_steps: if step not in command_names: msg.fail( f"Unknown command specified in workflow '{workflow_name}': {step}", f"Workflows can only refer to commands defined in the 'commands' " f"section of the {PROJECT_FILE}.", exits=1, ) def get_hash(data) -> str: """Get the hash for a JSON-serializable object. data: The data to hash. RETURNS (str): The hash. """ data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8") return hashlib.md5(data_str).hexdigest() def get_checksum(path: Union[Path, str]) -> str: """Get the checksum for a file or directory given its file path. If a directory path is provided, this uses all files in that directory. path (Union[Path, str]): The file or directory path. RETURNS (str): The checksum. """ path = Path(path) if path.is_file(): return hashlib.md5(Path(path).read_bytes()).hexdigest() if path.is_dir(): # TODO: this is currently pretty slow dir_checksum = hashlib.md5() for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()): dir_checksum.update(sub_file.read_bytes()) return dir_checksum.hexdigest() msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1) @contextmanager def show_validation_error( file_path: Optional[Union[str, Path]] = None, *, title: str = "Config validation error", hint_fill: bool = True, ): """Helper to show custom config validation errors on the CLI. file_path (str / Path): Optional file path of config file, used in hints. title (str): Title of the custom formatted error. hint_fill (bool): Show hint about filling config. """ try: yield except (ConfigValidationError, InterpolationError) as e: msg.fail(title, spaced=True) # TODO: This is kinda hacky and we should probably provide a better # helper for this in Thinc err_text = str(e).replace("Config validation error", "").strip() print(err_text) if hint_fill and "field required" in err_text: config_path = file_path if file_path is not None else "config.cfg" msg.text( "If your config contains missing values, you can run the 'init " "fill-config' command to fill in all the defaults, if possible:", spaced=True, ) print(f"{COMMAND} init fill-config {config_path} --base {config_path}\n") sys.exit(1) def import_code(code_path: Optional[Union[Path, str]]) -> None: """Helper to import Python file provided in training commands / commands using the config. This makes custom registered functions available. """ if code_path is not None: if not Path(code_path).exists(): msg.fail("Path to Python code not found", code_path, exits=1) try: import_file("python_code", code_path) except Exception as e: msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1) def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]: """RETURNS (List[str]): All sourced components in the original config, e.g. {"source": "en_core_web_sm"}. If the config contains a key "factory", we assume it refers to a component factory. """ return [ name for name, cfg in config.get("components", {}).items() if "factory" not in cfg and "source" in cfg ] def upload_file(src: Path, dest: Union[str, "Pathy"]) -> None: """Upload a file. src (Path): The source path. url (str): The destination URL to upload to. """ import smart_open # This logic is pretty hacky. We'd like pathy to do this probably? if ":/" not in str(dest): # Local path with Path(dest).open(mode="wb") as output_file: with src.open(mode="rb") as input_file: output_file.write(input_file.read()) elif str(dest).startswith("http") or str(dest).startswith("https"): with smart_open.open(str(dest), mode="wb") as output_file: with src.open(mode="rb") as input_file: output_file.write(input_file.read()) else: dest = ensure_pathy(dest) with dest.open(mode="wb") as output_file: with src.open(mode="rb") as input_file: output_file.write(input_file.read()) def download_file(src: Union[str, "Pathy"], dest: Path, *, force: bool = False) -> None: """Download a file using smart_open. url (str): The URL of the file. dest (Path): The destination path. force (bool): Whether to force download even if file exists. If False, the download will be skipped. """ import smart_open # This logic is pretty hacky. We'd like pathy to do this probably? if dest.exists() and not force: return None if src.startswith("http"): with smart_open.open(src, mode="rb") as input_file: with dest.open(mode="wb") as output_file: output_file.write(input_file.read()) elif ":/" not in src: with open(src, mode="rb") as input_file: with dest.open(mode="wb") as output_file: output_file.write(input_file.read()) else: src = ensure_pathy(src) with src.open(mode="rb") as input_file: with dest.open(mode="wb") as output_file: output_file.write(input_file.read()) def ensure_pathy(path): """Temporary helper to prevent importing Pathy globally (which can cause slow and annoying Google Cloud warning).""" from pathy import Pathy # noqa: F811 return Pathy(path) def git_sparse_checkout( repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None ): if dest.exists(): msg.fail("Destination of checkout must not exist", exits=1) if not dest.parent.exists(): msg.fail("Parent of destination of checkout must exist", exits=1) # We're using Git and sparse checkout to only clone the files we need with make_tempdir() as tmp_dir: cmd = ( f"git clone {repo} {tmp_dir} --no-checkout " "--depth 1 --config core.sparseCheckout=true" ) if branch is not None: cmd = f"{cmd} -b {branch}" run_command(cmd) with (tmp_dir / ".git" / "info" / "sparse-checkout").open("w") as f: f.write(subpath) run_command(["git", "-C", str(tmp_dir), "fetch"]) run_command(["git", "-C", str(tmp_dir), "checkout"]) # We need Path(name) to make sure we also support subdirectories shutil.move(str(tmp_dir / Path(subpath)), str(dest)) print(dest) print(list(dest.iterdir()))