diff --git a/spacy/cli/project.py b/spacy/cli/project.py index c02c1cf98..861a48f4b 100644 --- a/spacy/cli/project.py +++ b/spacy/cli/project.py @@ -295,15 +295,21 @@ def project_assets(project_dir: Path) -> None: msg.warn(f"No assets specified in {CONFIG_FILE}", exits=0) msg.info(f"Fetching {len(assets)} asset(s)") variables = config.get("variables", {}) + fetched_assets = [] for asset in assets: url = asset["url"].format(**variables) dest = asset["dest"].format(**variables) - fetch_asset(project_path, url, dest, asset.get("checksum")) + fetched_path = fetch_asset(project_path, url, dest, asset.get("checksum")) + if fetched_path: + fetched_assets.append(str(fetched_path)) + if fetched_assets: + with working_dir(project_path): + run_command(["dvc", "add", *fetched_assets, "--external"]) def fetch_asset( project_path: Path, url: str, dest: Path, checksum: Optional[str] = None -) -> None: +) -> Optional[Path]: """Fetch an asset from a given URL or path. Will try to import the file using DVC's import-url if possible (fully tracked and versioned) and falls back to get-url (versioned) and a non-DVC download if necessary. If a @@ -313,6 +319,8 @@ def fetch_asset( project_path (Path): Path to project directory. url (str): URL or path to asset. checksum (Optional[str]): Optional expected checksum of local file. + RETURNS (Optional[Path]): The path to the fetched asset or None if fetching + the asset failed. """ url = convert_asset_url(url) dest_path = (project_path / dest).resolve() @@ -321,8 +329,7 @@ def fetch_asset( # TODO: add support for caches (dvc import-url with local path) if checksum == get_checksum(dest_path): msg.good(f"Skipping download with matching checksum: {dest}") - return - dvc_add_cmd = ["dvc", "add", str(dest_path), "--external"] + return dest_path with working_dir(project_path): try: # If these fail, we don't want to output an error or info message. @@ -334,16 +341,16 @@ def fetch_asset( except subprocess.CalledProcessError: dvc_cmd = ["dvc", "get-url", url, str(dest_path)] print(subprocess.check_output(dvc_cmd, stderr=subprocess.DEVNULL)) - run_command(dvc_add_cmd) except subprocess.CalledProcessError: try: download_file(url, dest_path) except requests.exceptions.HTTPError as e: msg.fail(f"Download failed: {dest}", e) - run_command(dvc_add_cmd) + return None if checksum and checksum != get_checksum(dest_path): msg.warn(f"Checksum doesn't match value defined in {CONFIG_FILE}: {dest}") msg.good(f"Fetched asset {dest}") + return dest_path def project_run_all(project_dir: Path, *dvc_args) -> None: