from typing import Dict, Any, Union, List, Optional, Tuple, TYPE_CHECKING
import sys
import shutil
from pathlib import Path
from wasabi import msg
import srsly
import hashlib
import typer
import subprocess
from click import NoSuchOption
from typer.main import get_command
from contextlib import contextmanager
from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError

from ..schemas import ProjectConfigSchema, validate
from ..util import import_file, run_command, make_tempdir, registry

if TYPE_CHECKING:
    from pathy import Pathy  # noqa: F401


PROJECT_FILE = "project.yml"
PROJECT_LOCK = "project.lock"
COMMAND = "python -m spacy"
NAME = "spacy"
HELP = """spaCy Command-line Interface

DOCS: https://nightly.spacy.io/api/cli
"""
PROJECT_HELP = f"""Command-line interface for spaCy projects and templates.
You'd typically start by cloning a project template to a local directory and
fetching its assets like datasets etc. See the project's {PROJECT_FILE} for the
available commands.
"""
DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
commands to check and validate your config files, training and evaluation data,
and custom model implementations.
"""
INIT_HELP = """Commands for initializing configs and pipeline packages."""

# Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment.
Arg = typer.Argument
Opt = typer.Option

app = typer.Typer(name=NAME, help=HELP)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)

app.add_typer(project_cli)
app.add_typer(debug_cli)
app.add_typer(init_cli)


def setup_cli() -> None:
    # Make sure the CLI entry points run so that they get imported
    registry.cli.get_all()
    # Ensure that the help messages always display the correct prompt
    command = get_command(app)
    command(prog_name=COMMAND)


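# Illustrative example (not executed): extra CLI arguments like
# ["--training.batch_size", "128", "--nlp.lang=en"] are parsed by the helper
# below into {"training.batch_size": 128, "nlp.lang": "en"}. Values are read as
# JSON where possible and fall back to plain strings otherwise.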
def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
    """Generate a dictionary of config overrides based on the extra arguments
    provided on the CLI, e.g. --training.batch_size to override
    "training.batch_size". Arguments without a "." are considered invalid,
    since the config only allows top-level sections to exist.

    args (List[str]): The extra arguments from the command line.
    RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
    """
    result = {}
    while args:
        opt = args.pop(0)
        err = f"Invalid CLI argument '{opt}'"
        if opt.startswith("--"):  # new argument
            orig_opt = opt
            opt = opt.replace("--", "")
            if "." not in opt:
                raise NoSuchOption(orig_opt)
            if "=" in opt:  # we have --opt=value
                opt, value = opt.split("=", 1)
                opt = opt.replace("-", "_")
            else:
                if not args or args[0].startswith("--"):  # flag with no value
                    value = "true"
                else:
                    value = args.pop(0)
            # Just like we do in the config, we're calling json.loads on the
            # values. But since they come from the CLI, it'd be unintuitive to
            # explicitly mark strings with escaped quotes. So we're working
            # around that here by falling back to a string if parsing fails.
            # TODO: improve logic to handle simple types like list of strings?
            try:
                result[opt] = srsly.json_loads(value)
            except ValueError:
                result[opt] = str(value)
        else:
            msg.fail(f"{err}: override option should start with --", exits=1)
    return result


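# A minimal project.yml accepted by the loader below might look like this
# (illustrative sketch; the field names and values are assumptions):
#
#   vars:
#     lang: "en"
#   directories: ["assets", "training"]
#   commands:
#     - name: "download"
#       script: ["python -m spacy download ${vars.lang}_core_web_sm"]
#   workflows:
#     all: ["download"]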
def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]:
    """Load the project.yml file from a directory and validate it. Also make
    sure that all directories defined in the config exist.

    path (Path): The path to the project directory.
    interpolate (bool): Whether to substitute project variables.
    RETURNS (Dict[str, Any]): The loaded project.yml.
    """
    config_path = path / PROJECT_FILE
    if not config_path.exists():
        msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
    invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
    try:
        config = srsly.read_yaml(config_path)
    except ValueError as e:
        msg.fail(invalid_err, e, exits=1)
    errors = validate(ProjectConfigSchema, config)
    if errors:
        msg.fail(invalid_err)
        print("\n".join(errors))
        sys.exit(1)
    validate_project_commands(config)
    # Make sure directories defined in config exist
    for subdir in config.get("directories", []):
        dir_path = path / subdir
        if not dir_path.exists():
            dir_path.mkdir(parents=True)
    if interpolate:
        err = "project.yml validation error"
        with show_validation_error(title=err, hint_fill=False):
            config = substitute_project_variables(config)
    return config


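# For example (illustrative): with config = {"vars": {"lang": "en"}, ...}, a
# value written as "corpus/${vars.lang}" elsewhere in project.yml resolves to
# "corpus/en" after the substitution below.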
def substitute_project_variables(
    config: Dict[str, Any], overrides: Dict = {}
) -> Dict[str, Any]:
    """Interpolate variables in the project config and return the result."""
    key = "vars"
    config.setdefault(key, {})
    config[key].update(overrides)
    # Need to put variables in the top scope again so we can have a top-level
    # section "project" (otherwise, a list of commands in the top scope wouldn't
    # be allowed by Thinc's config system)
    cfg = Config({"project": config, key: config[key]})
    interpolated = cfg.interpolate()
    return dict(interpolated["project"])


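# E.g. a workflow {"all": ["preprocess", "train"]} is only valid if both
# "preprocess" and "train" are defined under "commands", and a workflow may not
# reuse an existing command name (illustrative names).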
def validate_project_commands(config: Dict[str, Any]) -> None:
    """Check that project commands and workflows are valid, don't contain
    duplicates, don't clash and only refer to commands that exist.

    config (Dict[str, Any]): The loaded config.
    """
    command_names = [cmd["name"] for cmd in config.get("commands", [])]
    workflows = config.get("workflows", {})
    duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
    if duplicates:
        err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
        msg.fail(err, exits=1)
    for workflow_name, workflow_steps in workflows.items():
        if workflow_name in command_names:
            err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
            msg.fail(err, exits=1)
        for step in workflow_steps:
            if step not in command_names:
                msg.fail(
                    f"Unknown command specified in workflow '{workflow_name}': {step}",
                    f"Workflows can only refer to commands defined in the 'commands' "
                    f"section of the {PROJECT_FILE}.",
                    exits=1,
                )


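# Illustrative: get_hash({"lang": "en", "version": 1}) returns the hex MD5
# digest of the sorted JSON dump, so equal data always produces the same
# 32-character string.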
def get_hash(data) -> str:
    """Get the hash for a JSON-serializable object.

    data: The data to hash.
    RETURNS (str): The hash.
    """
    data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
    return hashlib.md5(data_str).hexdigest()


def get_checksum(path: Union[Path, str]) -> str:
    """Get the checksum for a file or directory given its file path. If a
    directory path is provided, this uses all files in that directory.

    path (Union[Path, str]): The file or directory path.
    RETURNS (str): The checksum.
    """
    path = Path(path)
    if path.is_file():
        return hashlib.md5(path.read_bytes()).hexdigest()
    if path.is_dir():
        # TODO: this is currently pretty slow
        dir_checksum = hashlib.md5()
        for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
            dir_checksum.update(sub_file.read_bytes())
        return dir_checksum.hexdigest()
    msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)


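# Typical usage (sketch; assumes a config-loading call that may raise
# ConfigValidationError or InterpolationError):
#
#   with show_validation_error(config_path):
#       config = util.load_config(config_path, overrides=overrides)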
@contextmanager
def show_validation_error(
    file_path: Optional[Union[str, Path]] = None,
    *,
    title: str = "Config validation error",
    hint_fill: bool = True,
):
    """Helper to show custom config validation errors on the CLI.

    file_path (str / Path): Optional file path of config file, used in hints.
    title (str): Title of the custom formatted error.
    hint_fill (bool): Show hint about filling config.
    """
    try:
        yield
    except (ConfigValidationError, InterpolationError) as e:
        msg.fail(title, spaced=True)
        # TODO: This is kinda hacky and we should probably provide a better
        # helper for this in Thinc
        err_text = str(e).replace("Config validation error", "").strip()
        print(err_text)
        if hint_fill and "field required" in err_text:
            config_path = file_path if file_path is not None else "config.cfg"
            msg.text(
                "If your config contains missing values, you can run the 'init "
                "fill-config' command to fill in all the defaults, if possible:",
                spaced=True,
            )
            print(f"{COMMAND} init fill-config {config_path} --base {config_path}\n")
        sys.exit(1)


def import_code(code_path: Optional[Union[Path, str]]) -> None:
    """Helper to import Python file provided in training commands / commands
    using the config. This makes custom registered functions available.
    """
    if code_path is not None:
        if not Path(code_path).exists():
            msg.fail("Path to Python code not found", code_path, exits=1)
        try:
            import_file("python_code", code_path)
        except Exception as e:
            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)


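# For a config with components like {"ner": {"source": "en_core_web_sm"},
# "tagger": {"factory": "tagger"}}, the helper below returns ["ner"]
# (illustrative component names).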
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
    """RETURNS (List[str]): All sourced components in the original config,
        e.g. {"source": "en_core_web_sm"}. If the config contains a key
        "factory", we assume it refers to a component factory.
    """
    return [
        name
        for name, cfg in config.get("components", {}).items()
        if "factory" not in cfg and "source" in cfg
    ]


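# The two helpers below copy bytes via smart_open, so the remote side can be any
# URL scheme smart_open understands (e.g. an s3:// or gs:// location) as well as
# a local path (illustrative; exact scheme support depends on the installed
# smart_open extras).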
def upload_file(src: Path, dest: Union[str, "Pathy"]) -> None:
    """Upload a file.

    src (Path): The source path.
    dest (str / Pathy): The destination URL or path to upload to.
    """
    import smart_open

    dest = str(dest)
    with smart_open.open(dest, mode="wb") as output_file:
        with src.open(mode="rb") as input_file:
            output_file.write(input_file.read())


def download_file(src: Union[str, "Pathy"], dest: Path, *, force: bool = False) -> None:
    """Download a file using smart_open.

    src (str / Pathy): The URL or path of the file to download.
    dest (Path): The destination path.
    force (bool): Whether to force download even if file exists.
        If False, the download will be skipped.
    """
    import smart_open

    if dest.exists() and not force:
        return None
    src = str(src)
    with smart_open.open(src, mode="rb") as input_file:
        with dest.open(mode="wb") as output_file:
            output_file.write(input_file.read())


def ensure_pathy(path):
    """Temporary helper to prevent importing Pathy globally (which can cause
    a slow and annoying Google Cloud warning)."""
    from pathy import Pathy  # noqa: F811

    return Pathy(path)


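# Illustrative call (the repo, subpath and branch below are assumptions):
#
#   git_checkout(
#       "https://github.com/explosion/projects",
#       "pipelines/tagger_parser_ud",
#       Path("./my_project"),
#       branch="v3",
#       sparse=True,
#   )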
def git_checkout(
    repo: str, subpath: str, dest: Path, *, branch: str = "master", sparse: bool = False
):
    git_version = get_git_version()
    if dest.exists():
        msg.fail("Destination of checkout must not exist", exits=1)
    if not dest.parent.exists():
        raise IOError("Parent of destination of checkout must exist")
    # We're using Git, partial clone and sparse checkout to
    # only clone the files we need
    # This ends up being RIDICULOUS. omg.
    # So, every tutorial and SO post talks about 'sparse checkout'...But they
    # go and *clone* the whole repo. Worthless. And cloning part of a repo
    # turns out to be completely broken. The only way to specify a "path" is..
    # a path *on the server*? The contents of which, specifies the paths. Wat.
    # Obviously this is hopelessly broken and insecure, because you can query
    # arbitrary paths on the server! So nobody enables this.
    # What we have to do is disable *all* files. We could then just checkout
    # the path, and it'd "work", but be hopelessly slow...Because it goes and
    # transfers every missing object one-by-one. So the final piece is that we
    # need to use some weird git internals to fetch the missings in bulk, and
    # *that* we can do by path.
    with make_tempdir() as tmp_dir:
        supports_sparse = git_version >= (2, 22)
        use_sparse = supports_sparse and sparse
        # This is the "clone, but don't download anything" part.
        cmd = f"git clone {repo} {tmp_dir} --no-checkout --depth 1 -b {branch} "
        if use_sparse:
            cmd += "--filter=blob:none"  # <-- The key bit
        # Only show warnings if the user explicitly wants sparse checkout but
        # the Git version doesn't support it
        elif sparse:
            err_old = (
                f"You're running an old version of Git (v{git_version[0]}.{git_version[1]}) "
                f"that doesn't fully support sparse checkout yet."
            )
            err_unk = "You're running an unknown version of Git, so sparse checkout has been disabled."
            msg.warn(
                f"{err_unk if git_version == (0, 0) else err_old} "
                f"This means that more files than necessary may be downloaded "
                f"temporarily. To only download the files needed, make sure "
                f"you're using Git v2.22 or above."
            )
        try_run_command(cmd)
        # Now we need to find the missing filenames for the subpath we want.
        # Looking for this 'rev-list' command in the git --help? Hah.
        cmd = f"git -C {tmp_dir} rev-list --objects --all {'--missing=print ' if use_sparse else ''} -- {subpath}"
        ret = try_run_command(cmd)
        git_repo = _from_http_to_git(repo)
        # Now pass those missings into another bit of git internals
        missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
        if use_sparse and not missings:
            err = (
                f"Could not find any relevant files for '{subpath}'. "
                f"Did you specify a correct and complete path within repo '{repo}' "
                f"and branch {branch}?"
            )
            msg.fail(err, exits=1)
        if use_sparse:
            cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}"
            try_run_command(cmd)
        # And finally, we can checkout our subpath
        cmd = f"git -C {tmp_dir} checkout {branch} {subpath}"
        try_run_command(cmd)
        # We need Path(name) to make sure we also support subdirectories
        shutil.move(str(tmp_dir / Path(subpath)), str(dest))


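# E.g. a stdout of "git version 2.30.1" yields (2, 30); output that doesn't
# start with "git version" yields (0, 0).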
def get_git_version(
    error: str = "Could not run 'git'. Make sure it's installed and the executable is available.",
) -> Tuple[int, int]:
    """Get the version of git and raise an error if calling 'git --version' fails.

    error (str): The error message to show.
    RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
        (0, 0) if the version couldn't be determined.
    """
    ret = try_run_command(["git", "--version"], error=error)
    stdout = ret.stdout.strip()
    if not stdout or not stdout.startswith("git version"):
        return (0, 0)
    version = stdout[11:].strip().split(".")
    return (int(version[0]), int(version[1]))


def try_run_command(
    cmd: Union[str, List[str]], error: str = "Could not run command"
) -> subprocess.CompletedProcess:
    """Try running a command and show an error message and exit if it fails.

    cmd (Union[str, List[str]]): The command to run.
    error (str): The error message.
    RETURNS (CompletedProcess): The completed process if the command ran.
    """
    try:
        return run_command(cmd, capture=True)
    except subprocess.CalledProcessError:
        msg.fail(error)
        print(cmd)
        sys.exit(1)


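# E.g. "https://github.com/explosion/projects" becomes
# "git@github.com:explosion/projects.git" (illustrative repo URL).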
def _from_http_to_git(repo: str) -> str:
    if repo.startswith("http://"):
        repo = repo.replace(r"http://", r"https://")
    if repo.startswith(r"https://"):
        repo = repo.replace("https://", "git@").replace("/", ":", 1)
        if repo.endswith("/"):
            repo = repo[:-1]
        repo = f"{repo}.git"
    return repo


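# E.g. string_to_list("[en, 'fr']") returns ["en", "fr"] and
# string_to_list("1,2,3", intify=True) returns [1, 2, 3] (illustrative inputs).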
def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[int]]:
    """Parse a comma-separated string to a list and account for various
    formatting options. Mostly used to handle CLI arguments that take a list of
    comma-separated values.

    value (str): The value to parse.
    intify (bool): Whether to convert values to ints.
    RETURNS (Union[List[str], List[int]]): A list of strings or ints.
    """
    if not value:
        return []
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    result = []
    for p in value.split(","):
        p = p.strip()
        if p.startswith("'") and p.endswith("'"):
            p = p[1:-1]
        if p.startswith('"') and p.endswith('"'):
            p = p[1:-1]
        p = p.strip()
        if intify:
            p = int(p)
        result.append(p)
    return result