import hashlib
import os
import shutil
import sys
from configparser import InterpolationError
from contextlib import contextmanager
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    overload,
)

import srsly
import typer
from click import NoSuchOption
from click.parser import split_arg_string
from thinc.api import Config, ConfigValidationError, require_gpu
from thinc.util import gpu_is_available
from typer.main import get_command
from wasabi import Printer, msg

from .. import about
from ..errors import RENAMED_LANGUAGE_CODES
from ..schemas import ProjectConfigSchema, validate
from ..util import (
    ENV_VARS,
    SimpleFrozenDict,
    import_file,
    is_compatible_version,
    logger,
    make_tempdir,
    registry,
    run_command,
)

# Only needed for type annotations; importing it lazily keeps Pathy out of the
# regular import path (see ensure_pathy further down in this module).
if TYPE_CHECKING:
    from pathy import FluidPath  # noqa: F401


# File name suffixes; not referenced elsewhere in this module.
# NOTE(review): presumably used by the packaging commands -- confirm.
SDIST_SUFFIX = ".tar.gz"
WHEEL_SUFFIX = "-py3-none-any.whl"

# Name of the project file loaded by load_project_config and quoted in the
# project-related error messages.
PROJECT_FILE = "project.yml"
PROJECT_LOCK = "project.lock"
# Program name passed to the CLI in setup_cli and shown in command hints.
COMMAND = "python -m spacy"
# Name and help texts handed to the Typer apps created below.
NAME = "spacy"
HELP = """spaCy Command-line Interface DOCS: https://spacy.io/api/cli """
PROJECT_HELP = f"""Command-line interface for spaCy projects and templates. You'd typically start by cloning a project template to a local directory and fetching its assets like datasets etc. See the project's {PROJECT_FILE} for the available commands. """
DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes commands to check and validate your config files, training and evaluation data, and custom model implementations. """
BENCHMARK_HELP = """Commands for benchmarking pipelines."""
INIT_HELP = """Commands for initializing configs and pipeline packages."""

# Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment.
Arg = typer.Argument
Opt = typer.Option

app = typer.Typer(name=NAME, help=HELP)
benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_help=True)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)

app.add_typer(project_cli)
app.add_typer(debug_cli)
app.add_typer(benchmark_cli)
app.add_typer(init_cli)


def setup_cli() -> None:
    """Load the registered CLI functions and run the Typer app as a command."""
    # Make sure the entry-point for CLI runs, so that they get imported.
    registry.cli.get_all()
    # Ensure that the help messages always display the correct prompt
    command = get_command(app)
    command(prog_name=COMMAND)


def parse_config_overrides(
    args: List[str], env_var: Optional[str] = ENV_VARS.CONFIG_OVERRIDES
) -> Dict[str, Any]:
    """Generate a dictionary of config overrides based on the extra arguments
    provided on the CLI, e.g. --training.batch_size to override
    "training.batch_size". Arguments without a "." are considered invalid,
    since the config only allows top-level sections to exist.

    If the same key is set on the CLI and in the environment variable, the
    value from the environment variable is the one that ends up in the result.

    args (List[str]): The extra CLI arguments to parse.
    env_var (Optional[str]): Optional environment variable to read from.
    RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
    """
    env_string = os.environ.get(env_var, "") if env_var else ""
    env_overrides = _parse_overrides(split_arg_string(env_string))
    cli_overrides = _parse_overrides(args, is_cli=True)
    if cli_overrides:
        # Only log the CLI keys that aren't shadowed by the env variable
        keys = [k for k in cli_overrides if k not in env_overrides]
        logger.debug("Config overrides from CLI: %s", keys)
    if env_overrides:
        logger.debug("Config overrides from env variables: %s", list(env_overrides))
    return {**cli_overrides, **env_overrides}


def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
    """Parse a flat list of "--section.key value" / "--section.key=value"
    arguments into a dict of overrides.

    args (List[str]): The arguments to parse. The list is not modified.
    is_cli (bool): Whether the arguments come from the CLI. If so, a name
        without a "." raises NoSuchOption instead of printing an error.
    RETURNS (Dict[str, Any]): The parsed overrides, keyed by option name.
    """
    result = {}
    # Work on a copy so the caller's list isn't consumed as a side effect
    remaining = list(args)
    idx = 0
    while idx < len(remaining):
        opt = remaining[idx]
        idx += 1
        err = f"Invalid config override '{opt}'"
        if opt.startswith("--"):  # new argument
            orig_opt = opt
            # Strip only the leading dashes and split off an inline value
            # first, so neither the "--" removal nor the "." check below is
            # affected by the contents of the value.
            name, has_inline_value, inline_value = opt[2:].partition("=")
            if "." not in name:
                if is_cli:
                    raise NoSuchOption(orig_opt)
                else:
                    msg.fail(f"{err}: can't override top-level sections", exits=1)
            if has_inline_value:  # we have --opt=value
                # NOTE(review): dashes are only normalized for the
                # --opt=value form, as before -- confirm whether the
                # space-separated form should be treated the same way.
                name = name.replace("-", "_")
                value = inline_value
            else:
                if idx >= len(remaining) or remaining[idx].startswith("--"):
                    # flag with no value
                    value = "true"
                else:
                    value = remaining[idx]
                    idx += 1
            result[name] = _parse_override(value)
        else:
            msg.fail(f"{err}: name should start with --", exits=1)
    return result


def _parse_override(value: Any) -> Any:
    """Parse a single override value as JSON, falling back to a string.

    value (Any): The raw value.
    RETURNS (Any): The JSON-decoded value, or str(value) if decoding fails.
    """
    # Just like we do in the config, we're calling json.loads on the
    # values. But since they come from the CLI, it'd be unintuitive to
    # explicitly mark strings with escaped quotes. So we're working
    # around that here by falling back to a string if parsing fails.
    # TODO: improve logic to handle simple types like list of strings?
    try:
        return srsly.json_loads(value)
    except ValueError:
        return str(value)


def _handle_renamed_language_codes(lang: Optional[str]) -> None:
    """Fail with an error message if the language code was renamed.

    lang (Optional[str]): The language code to check.
    """
    # Throw error for renamed language codes in v4
    if lang in RENAMED_LANGUAGE_CODES:
        msg.fail(
            title="Renamed language code",
            text=f"Language code '{lang}' was replaced with '{RENAMED_LANGUAGE_CODES[lang]}' in spaCy v4. Update the language code from '{lang}' to '{RENAMED_LANGUAGE_CODES[lang]}'.",
            exits=1,
        )


def load_project_config(
    path: Path, interpolate: bool = True, overrides: Dict[str, Any] = SimpleFrozenDict()
) -> Dict[str, Any]:
    """Load the project.yml file from a directory and validate it. Also make
    sure that all directories defined in the config exist.

    path (Path): The path to the project directory.
    interpolate (bool): Whether to substitute project variables.
    overrides (Dict[str, Any]): Optional config overrides.
    RETURNS (Dict[str, Any]): The loaded project.yml.
    """
    config_path = path / PROJECT_FILE
    if not config_path.exists():
        msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
    invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
    try:
        config = srsly.read_yaml(config_path)
    except ValueError as e:
        msg.fail(invalid_err, e, exits=1)
    errors = validate(ProjectConfigSchema, config)
    if errors:
        msg.fail(invalid_err)
        print("\n".join(errors))
        sys.exit(1)
    validate_project_version(config)
    validate_project_commands(config)
    if interpolate:
        err = f"{PROJECT_FILE} validation error"
        with show_validation_error(title=err, hint_fill=False):
            config = substitute_project_variables(config, overrides)
    # Make sure directories defined in config exist
    for subdir in config.get("directories", []):
        dir_path = path / subdir
        if not dir_path.exists():
            # exist_ok guards against the directory being created by someone
            # else between the check above and this call
            dir_path.mkdir(parents=True, exist_ok=True)
    return config


def substitute_project_variables(
    config: Dict[str, Any],
    overrides: Dict[str, Any] = SimpleFrozenDict(),
    key: str = "vars",
    env_key: str = "env",
) -> Dict[str, Any]:
    """Interpolate variables in the project file using the config system.

    config (Dict[str, Any]): The project config. The "vars" and "env" sections
        are created and updated in place.
    overrides (Dict[str, Any]): Optional config overrides.
    key (str): Key containing variables in project config.
    env_key (str): Key containing environment variable mapping in project config.
    RETURNS (Dict[str, Any]): The interpolated project config.
    """
    config.setdefault(key, {})
    config.setdefault(env_key, {})
    # Substitute references to env vars with their values
    for config_var, env_var in config[env_key].items():
        config[env_key][config_var] = _parse_override(os.environ.get(env_var, ""))
    # Need to put variables in the top scope again so we can have a top-level
    # section "project" (otherwise, a list of commands in the top scope wouldn't)
    # be allowed by Thinc's config system
    cfg = Config({"project": config, key: config[key], env_key: config[env_key]})
    cfg = Config().from_str(cfg.to_str(), overrides=overrides)
    interpolated = cfg.interpolate()
    return dict(interpolated["project"])


def validate_project_version(config: Dict[str, Any]) -> None:
    """If the project defines a compatible spaCy version range, check that it's
    compatible with the current version of spaCy.

    config (Dict[str, Any]): The loaded config.
    """
    spacy_version = config.get("spacy_version", None)
    if spacy_version and not is_compatible_version(about.__version__, spacy_version):
        err = (
            f"The {PROJECT_FILE} specifies a spaCy version range ({spacy_version}) "
            f"that's not compatible with the version of spaCy you're running "
            f"({about.__version__}). You can edit version requirement in the "
            f"{PROJECT_FILE} to load it, but the project may not run as expected."
        )
        msg.fail(err, exits=1)


def validate_project_commands(config: Dict[str, Any]) -> None:
    """Check that project commands and workflows are valid, don't contain
    duplicates, don't clash and only refer to commands that exist.

    config (Dict[str, Any]): The loaded config.
    """
    command_names = [cmd["name"] for cmd in config.get("commands", [])]
    workflows = config.get("workflows", {})
    # Single pass instead of calling list.count() for every name
    seen = set()
    duplicates = set()
    for name in command_names:
        if name in seen:
            duplicates.add(name)
        seen.add(name)
    if duplicates:
        # Sorted so the message is the same from run to run
        err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(sorted(duplicates))}"
        msg.fail(err, exits=1)
    for workflow_name, workflow_steps in workflows.items():
        if workflow_name in command_names:
            err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
            msg.fail(err, exits=1)
        for step in workflow_steps:
            if step not in command_names:
                msg.fail(
                    f"Unknown command specified in workflow '{workflow_name}': {step}",
                    f"Workflows can only refer to commands defined in the 'commands' "
                    f"section of the {PROJECT_FILE}.",
                    exits=1,
                )


def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
    """Get the hash for a JSON-serializable object.

    data: The data to hash.
    exclude (Iterable[str]): Top-level keys to exclude if data is a dict.
    RETURNS (str): The hash.
    """
    if isinstance(data, dict):
        data = {k: v for k, v in data.items() if k not in exclude}
    data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
    return hashlib.md5(data_str).hexdigest()


def get_checksum(path: Union[Path, str]) -> str:
    """Get the checksum for a file or directory given its file path. If a
    directory path is provided, this uses all files in that directory.

    path (Union[Path, str]): The file or directory path.
    RETURNS (str): The checksum.
    """
    path = Path(path)
    if not (path.is_file() or path.is_dir()):
        msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
    if path.is_file():
        return hashlib.md5(path.read_bytes()).hexdigest()
    else:
        # TODO: this is currently pretty slow
        dir_checksum = hashlib.md5()
        # Sorted so the checksum doesn't depend on directory listing order
        for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
            dir_checksum.update(sub_file.read_bytes())
        return dir_checksum.hexdigest()


@contextmanager
def show_validation_error(
    file_path: Optional[Union[str, Path]] = None,
    *,
    title: Optional[str] = None,
    desc: str = "",
    show_config: Optional[bool] = None,
    hint_fill: bool = True,
):
    """Helper to show custom config validation errors on the CLI.

    file_path (str / Path): Optional file path of config file, used in hints.
    title (str): Override title of custom formatted error.
    desc (str): Override description of custom formatted error.
    show_config (bool): Whether to output the config the error refers to.
    hint_fill (bool): Show hint about filling config.
    """
    try:
        yield
    except ConfigValidationError as e:
        title = title if title is not None else e.title
        if e.desc:
            desc = f"{e.desc}" if not desc else f"{e.desc}\n\n{desc}"
        # Re-generate a new error object with overrides
        err = e.from_error(e, title="", desc=desc, show_config=show_config)
        msg.fail(title)
        print(err.text.strip())
        if hint_fill and "value_error.missing" in err.error_types:
            config_path = (
                file_path
                if file_path is not None and str(file_path) != "-"
                else "config.cfg"
            )
            msg.text(
                "If your config contains missing values, you can run the 'init "
                "fill-config' command to fill in all the defaults, if possible:",
                spaced=True,
            )
            print(f"{COMMAND} init fill-config {config_path} {config_path} \n")
        sys.exit(1)
    except InterpolationError as e:
        msg.fail("Config validation error", e, exits=1)


def import_code_paths(code_paths: str) -> None:
    """Helper to import comma-separated list of code paths.

    code_paths (str): Comma-separated paths to Python files.
    """
    # Separate name, so the str parameter isn't rebound to a list of paths
    parsed_paths = [Path(p.strip()) for p in string_to_list(code_paths)]
    for code_path in parsed_paths:
        import_code(code_path)


def import_code(code_path: Optional[Union[Path, str]]) -> None:
    """Helper to import Python file provided in training commands / commands
    using the config. This makes custom registered functions available.

    code_path (Optional[Union[Path, str]]): Path to the Python file. Nothing
        happens if it's None.
    """
    if code_path is not None:
        if not Path(code_path).exists():
            msg.fail("Path to Python code not found", code_path, exits=1)
        try:
            import_file("python_code", code_path)
        except Exception as e:
            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)


def upload_file(src: Path, dest: Union[str, "FluidPath"]) -> None:
    """Upload a file.

    src (Path): The source path.
    dest (Union[str, FluidPath]): The destination URL or path to upload to.
    """
    import smart_open

    # Create parent directories for local paths
    if isinstance(dest, Path):
        if not dest.parent.exists():
            dest.parent.mkdir(parents=True, exist_ok=True)

    dest = str(dest)
    with smart_open.open(dest, mode="wb") as output_file:
        with src.open(mode="rb") as input_file:
            # Stream in chunks (like download_file does) instead of loading
            # the whole file into memory first
            shutil.copyfileobj(input_file, output_file)


def download_file(
    src: Union[str, "FluidPath"], dest: Path, *, force: bool = False
) -> None:
    """Download a file using smart_open.

    src (Union[str, FluidPath]): The URL or path of the file.
    dest (Path): The destination path.
    force (bool): Whether to force download even if file exists.
        If False, the download will be skipped.
    """
    import smart_open

    if dest.exists() and not force:
        return None
    src = str(src)
    with smart_open.open(src, mode="rb", compression="disable") as input_file:
        with dest.open(mode="wb") as output_file:
            shutil.copyfileobj(input_file, output_file)


def ensure_pathy(path):
    """Temporary helper to prevent importing Pathy globally (which can cause
    slow and annoying Google Cloud warning)."""
    from pathy import Pathy  # noqa: F811

    return Pathy.fluid(path)


def git_checkout(
    repo: str, subpath: str, dest: Path, *, branch: str = "master", sparse: bool = False
):
    """Copy a subpath of a Git repository into a new local directory.

    repo (str): The repository to clone.
    subpath (str): The path within the repository to copy.
    dest (Path): The destination. Must not exist yet, but its parent must.
    branch (str): The branch to clone.
    sparse (bool): Whether to use a sparse checkout. Only done if the
        installed Git is v2.22 or above, otherwise a warning is shown and the
        whole repository is cloned into a temporary directory.
    """
    git_version = get_git_version()
    if dest.exists():
        msg.fail("Destination of checkout must not exist", exits=1)
    if not dest.parent.exists():
        msg.fail("Parent of destination of checkout must exist", exits=1)
    if sparse and git_version >= (2, 22):
        return git_sparse_checkout(repo, subpath, dest, branch)
    elif sparse:
        # Only show warnings if the user explicitly wants sparse checkout but
        # the Git version doesn't support it
        err_old = (
            f"You're running an old version of Git (v{git_version[0]}.{git_version[1]}) "
            f"that doesn't fully support sparse checkout yet."
        )
        err_unk = "You're running an unknown version of Git, so sparse checkout has been disabled."
        msg.warn(
            f"{err_unk if git_version == (0, 0) else err_old} "
            f"This means that more files than necessary may be downloaded "
            f"temporarily. To only download the files needed, make sure "
            f"you're using Git v2.22 or above."
        )
    with make_tempdir() as tmp_dir:
        cmd = f"git -C {tmp_dir} clone {repo} . -b {branch}"
        run_command(cmd, capture=True)
        # We need Path(name) to make sure we also support subdirectories
        try:
            source_path = tmp_dir / Path(subpath)
            if not is_subpath_of(tmp_dir, source_path):
                err = f"'{subpath}' is a path outside of the cloned repository."
                msg.fail(err, repo, exits=1)
            shutil.copytree(str(source_path), str(dest))
        except FileNotFoundError:
            err = f"Can't clone {subpath}. Make sure the directory exists in the repo (branch '{branch}')"
            msg.fail(err, repo, exits=1)


def git_sparse_checkout(repo, subpath, dest, branch):
    """Fetch only the files under a subpath of a Git repository and move them
    to the destination.

    repo (str): The repository to clone.
    subpath (str): The path within the repository to fetch.
    dest (Path): The destination to move the checked out subpath to.
    branch (str): The branch to check out.
    """
    # We're using Git, partial clone and sparse checkout to
    # only clone the files we need
    # This ends up being RIDICULOUS. omg.
    # So, every tutorial and SO post talks about 'sparse checkout'...But they
    # go and *clone* the whole repo. Worthless. And cloning part of a repo
    # turns out to be completely broken. The only way to specify a "path" is..
    # a path *on the server*? The contents of which, specifies the paths. Wat.
    # Obviously this is hopelessly broken and insecure, because you can query
    # arbitrary paths on the server! So nobody enables this.
    # What we have to do is disable *all* files. We could then just checkout
    # the path, and it'd "work", but be hopelessly slow...Because it goes and
    # transfers every missing object one-by-one. So the final piece is that we
    # need to use some weird git internals to fetch the missings in bulk, and
    # *that* we can do by path.
    # We're using Git and sparse checkout to only clone the files we need
    with make_tempdir() as tmp_dir:
        # This is the "clone, but don't download anything" part.
        cmd = (
            f"git clone {repo} {tmp_dir} --no-checkout --depth 1 "
            f"-b {branch} --filter=blob:none"
        )
        run_command(cmd)
        # Now we need to find the missing filenames for the subpath we want.
        # Looking for this 'rev-list' command in the git --help? Hah.
        cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
        ret = run_command(cmd, capture=True)
        git_repo = _http_to_git(repo)
        # Now pass those missings into another bit of git internals
        missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
        if not missings:
            err = (
                f"Could not find any relevant files for '{subpath}'. "
                f"Did you specify a correct and complete path within repo '{repo}' "
                f"and branch {branch}?"
            )
            msg.fail(err, exits=1)
        cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}"
        run_command(cmd, capture=True)
        # And finally, we can checkout our subpath
        cmd = f"git -C {tmp_dir} checkout {branch} {subpath}"
        run_command(cmd, capture=True)

        # Get a subdirectory of the cloned path, if appropriate
        source_path = tmp_dir / Path(subpath)
        if not is_subpath_of(tmp_dir, source_path):
            err = f"'{subpath}' is a path outside of the cloned repository."
            msg.fail(err, repo, exits=1)

        shutil.move(str(source_path), str(dest))


def git_repo_branch_exists(repo: str, branch: str) -> bool:
    """Uses 'git ls-remote' to check if a repository and branch exists

    repo (str): URL to get repo.
    branch (str): Branch on repo to check.
    RETURNS (bool): True if repo:branch exists.
    """
    # Called for its side effect only: raises if 'git' can't be run
    get_git_version()
    cmd = f"git ls-remote {repo} {branch}"
    # We might be tempted to use `--exit-code` with `git ls-remote`, but
    # `run_command` handles the `returncode` for us, so we'll rely on
    # the fact that stdout returns '' if the requested branch doesn't exist
    ret = run_command(cmd, capture=True)
    exists = ret.stdout != ""
    return exists


def get_git_version(
    error: str = "Could not run 'git'. Make sure it's installed and the executable is available.",
) -> Tuple[int, int]:
    """Get the version of git and raise an error if calling 'git --version' fails.

    error (str): The error message to show.
    RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
        (0, 0) if the version couldn't be determined.
    RAISES (RuntimeError): If running 'git --version' fails.
    """
    try:
        ret = run_command("git --version", capture=True)
    except (Exception, SystemExit) as e:
        # Don't use a bare except, so Ctrl+C isn't turned into a RuntimeError.
        # SystemExit is still translated in case run_command (defined
        # elsewhere) reports a failure by exiting.
        raise RuntimeError(error) from e
    stdout = ret.stdout.strip()
    if not stdout or not stdout.startswith("git version"):
        return 0, 0
    # Drop the "git version" prefix (11 characters) and split the rest
    version = stdout[11:].strip().split(".")
    try:
        return int(version[0]), int(version[1])
    except (ValueError, IndexError):
        # Unexpected version format: treat it as unknown, as documented
        return 0, 0


def _http_to_git(repo: str) -> str:
    """Convert an http(s) repo URL to the "git@host:path.git" form.

    repo (str): The repo URL.
    RETURNS (str): The converted URL. Anything that doesn't start with
        http:// or https:// is returned unchanged.
    """
    if repo.startswith("http://"):
        # Only rewrite the scheme, not later occurrences in the URL
        repo = repo.replace(r"http://", r"https://", 1)
    if repo.startswith(r"https://"):
        repo = repo.replace("https://", "git@", 1).replace("/", ":", 1)
        if repo.endswith("/"):
            repo = repo[:-1]
        # Don't produce "name.git.git" if the URL already has the suffix
        if not repo.endswith(".git"):
            repo = f"{repo}.git"
    return repo


def is_subpath_of(parent, child):
    """
    Check whether `child` is a path contained within `parent`.

    parent: The path of the containing directory.
    child: The path to check.
    RETURNS (bool): True if the resolved child path is the resolved parent
        path or lies within it.
    """
    # Based on https://stackoverflow.com/a/37095733 .

    # In Python 3.9, the `Path.is_relative_to()` method will supplant this, so
    # we can stop using crusty old os.path functions.
    parent_realpath = os.path.realpath(parent)
    child_realpath = os.path.realpath(child)
    try:
        common = os.path.commonpath([parent_realpath, child_realpath])
    except ValueError:
        # commonpath raises if the paths can't be compared, e.g. when they're
        # on different drives -- the child can't be inside the parent then
        return False
    return common == parent_realpath


@overload
def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
    ...


@overload
def string_to_list(value: str, intify: Literal[True]) -> List[int]:
    ...


def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[int]]:
    """Parse a comma-separated string to a list and account for various
    formatting options. Mostly used to handle CLI arguments that take a list of
    comma-separated values.

    value (str): The value to parse.
    intify (bool): Whether to convert values to ints.
    RETURNS (Union[List[str], List[int]]): A list of strings or ints.
    """
    if not value:
        return []  # type: ignore[return-value]
    # Allow the list to be written with surrounding brackets
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    result = []
    for p in value.split(","):
        p = p.strip()
        # Strip one level of matching single and/or double quotes
        if p.startswith("'") and p.endswith("'"):
            p = p[1:-1]
        if p.startswith('"') and p.endswith('"'):
            p = p[1:-1]
        p = p.strip()
        if intify:
            p = int(p)  # type: ignore[assignment]
        result.append(p)
    return result


def setup_gpu(use_gpu: int, silent=None) -> None:
    """Configure the GPU and log info.

    use_gpu (int): The GPU ID to require. A negative value means CPU.
    silent: If None, messages are printed with a default Printer. Otherwise
        used to configure the Printer (no_print=silent, pretty=not silent).
    """
    if silent is None:
        local_msg = Printer()
    else:
        local_msg = Printer(no_print=silent, pretty=not silent)
    if use_gpu >= 0:
        local_msg.info(f"Using GPU: {use_gpu}")
        require_gpu(use_gpu)
    else:
        local_msg.info("Using CPU")
        if gpu_is_available():
            local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")


def walk_directory(path: Path, suffix: Optional[str] = None) -> List[Path]:
    """Given a directory and a suffix, recursively find all files matching the suffix.
    Directories or files with names beginning with a . are ignored, but
    hidden flags on filesystems are not checked.
    When provided with a suffix `None`, there is no suffix-based filtering.

    path (Path): The directory to walk. If it isn't a directory, it's
        returned as the only element of the result.
    suffix (Optional[str]): The suffix file names have to end with.
    RETURNS (List[Path]): The sorted list of matching file paths.
    """
    if not path.is_dir():
        return [path]
    # The list doubles as a work queue: directories append their entries to
    # it while it's being iterated over.
    paths = [path]
    locs = []
    seen = set()
    for current in paths:
        if str(current) in seen:
            continue
        seen.add(str(current))
        # Use .name rather than .parts[-1]: Path(".") has no parts at all,
        # which would raise an IndexError
        if current.name.startswith("."):
            continue
        elif current.is_dir():
            paths.extend(current.iterdir())
        elif suffix is not None and not current.name.endswith(suffix):
            continue
        else:
            locs.append(current)
    # It's good to sort these, in case the ordering messes up cache.
    locs.sort()
    return locs


def _format_number(number: Union[int, float], ndigits: int = 2) -> str:
    """Formats a number (float or int) rounding to `ndigits`, without truncating trailing 0s,
    as happens with `round(number, ndigits)`

    number (Union[int, float]): The number to format.
    ndigits (int): Number of decimal places, only applied to floats.
    RETURNS (str): The formatted number.
    """
    if isinstance(number, float):
        return f"{number:.{ndigits}f}"
    else:
        return str(number)