from typing import Any, Dict, Optional, List from pathlib import Path from wasabi import msg import os import re import shutil import requests from radicli import Arg, ExistingDirPath from ...util import ensure_path, working_dir from .._util import cli, PROJECT_FILE, load_project_config from .._util import get_checksum, download_file, git_checkout, get_git_version from .._util import SimpleFrozenDict, parse_config_overrides # Whether assets are extra if `extra` is not set. EXTRA_DEFAULT = False @cli.subcommand( "project", "assets", # fmt: off project_dir=Arg(help="Path to cloned project. Defaults to current working directory"), sparse_checkout=Arg("--sparse", "-S", help="Use sparse checkout for assets provided via Git, to only check out and clone the files needed. Requires Git v22.2+"), extra=Arg("--extra", "-e", help="Download all assets, including those marked as 'extra'"), # fmt: on ) def project_assets_cli( project_dir: ExistingDirPath = Path.cwd(), sparse_checkout: bool = False, extra: bool = False, _extra: List[str] = [], ): """Fetch project assets like datasets and pretrained weights. Assets are defined in the "assets" section of the project.yml. If a checksum is provided in the project.yml, the file is only downloaded if no local file with the same checksum exists. DOCS: https://spacy.io/api/cli#project-assets """ overrides = parse_config_overrides(_extra) project_assets( project_dir, overrides=overrides, sparse_checkout=sparse_checkout, extra=extra, ) def project_assets( project_dir: Path, *, overrides: Dict[str, Any] = SimpleFrozenDict(), sparse_checkout: bool = False, extra: bool = False, ) -> None: """Fetch assets for a project using DVC if possible. project_dir (Path): Path to project directory. sparse_checkout (bool): Use sparse checkout for assets provided via Git, to only check out and clone the files needed. extra (bool): Whether to download all assets, including those marked as 'extra'. """ project_path = ensure_path(project_dir) config = load_project_config(project_path, overrides=overrides) assets = [ asset for asset in config.get("assets", []) if extra or not asset.get("extra", EXTRA_DEFAULT) ] if not assets: msg.warn( f"No assets specified in {PROJECT_FILE} (if assets are marked as extra, download them with --extra)", exits=0, ) msg.info(f"Fetching {len(assets)} asset(s)") for asset in assets: dest = (project_dir / asset["dest"]).resolve() checksum = asset.get("checksum") if "git" in asset: git_err = ( f"Cloning spaCy project templates requires Git and the 'git' command. " f"Make sure it's installed and that the executable is available." ) get_git_version(error=git_err) if dest.exists(): # If there's already a file, check for checksum if checksum and checksum == get_checksum(dest): msg.good( f"Skipping download with matching checksum: {asset['dest']}" ) continue else: if dest.is_dir(): shutil.rmtree(dest) else: dest.unlink() if "repo" not in asset["git"] or asset["git"]["repo"] is None: msg.fail( "A git asset must include 'repo', the repository address.", exits=1 ) if "path" not in asset["git"] or asset["git"]["path"] is None: msg.fail( "A git asset must include 'path' - use \"\" to get the entire repository.", exits=1, ) git_checkout( asset["git"]["repo"], asset["git"]["path"], dest, branch=asset["git"].get("branch"), sparse=sparse_checkout, ) msg.good(f"Downloaded asset {dest}") else: url = asset.get("url") if not url: # project.yml defines asset without URL that the user has to place check_private_asset(dest, checksum) continue fetch_asset(project_path, url, dest, checksum) def check_private_asset(dest: Path, checksum: Optional[str] = None) -> None: """Check and validate assets without a URL (private assets that the user has to provide themselves) and give feedback about the checksum. dest (Path): Destination path of the asset. checksum (Optional[str]): Optional checksum of the expected file. """ if not Path(dest).exists(): err = f"No URL provided for asset. You need to add this file yourself: {dest}" msg.warn(err) else: if not checksum: msg.good(f"Asset already exists: {dest}") elif checksum == get_checksum(dest): msg.good(f"Asset exists with matching checksum: {dest}") else: msg.fail(f"Asset available but with incorrect checksum: {dest}") def fetch_asset( project_path: Path, url: str, dest: Path, checksum: Optional[str] = None ) -> None: """Fetch an asset from a given URL or path. If a checksum is provided and a local file exists, it's only re-downloaded if the checksum doesn't match. project_path (Path): Path to project directory. url (str): URL or path to asset. checksum (Optional[str]): Optional expected checksum of local file. RETURNS (Optional[Path]): The path to the fetched asset or None if fetching the asset failed. """ dest_path = (project_path / dest).resolve() if dest_path.exists(): # If there's already a file, check for checksum if checksum: if checksum == get_checksum(dest_path): msg.good(f"Skipping download with matching checksum: {dest}") return else: # If there's not a checksum, make sure the file is a possibly valid size if os.path.getsize(dest_path) == 0: msg.warn(f"Asset exists but with size of 0 bytes, deleting: {dest}") os.remove(dest_path) # We might as well support the user here and create parent directories in # case the asset dir isn't listed as a dir to create in the project.yml if not dest_path.parent.exists(): dest_path.parent.mkdir(parents=True) with working_dir(project_path): url = convert_asset_url(url) try: download_file(url, dest_path) msg.good(f"Downloaded asset {dest}") except requests.exceptions.RequestException as e: if Path(url).exists() and Path(url).is_file(): # If it's a local file, copy to destination shutil.copy(url, str(dest_path)) msg.good(f"Copied local asset {dest}") else: msg.fail(f"Download failed: {dest}", e) if checksum and checksum != get_checksum(dest_path): msg.fail(f"Checksum doesn't match value defined in {PROJECT_FILE}: {dest}") def convert_asset_url(url: str) -> str: """Check and convert the asset URL if needed. url (str): The asset URL. RETURNS (str): The converted URL. """ # If the asset URL is a regular GitHub URL it's likely a mistake if ( re.match(r"(http(s?)):\/\/github.com", url) and "releases/download" not in url and "/raw/" not in url ): converted = url.replace("github.com", "raw.githubusercontent.com") converted = re.sub(r"/(tree|blob)/", "/", converted) msg.warn( "Downloading from a regular GitHub URL. This will only download " "the source of the page, not the actual file. Converting the URL " "to a raw URL.", converted, ) return converted return url