mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Update project CLI
This commit is contained in:
		
							parent
							
								
									e0c16c0577
								
							
						
					
					
						commit
						5ba1df5e78
					
				|  | @ -4,20 +4,43 @@ import srsly | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from wasabi import msg | from wasabi import msg | ||||||
| import shlex | import shlex | ||||||
|  | import os | ||||||
|  | import re | ||||||
| 
 | 
 | ||||||
| from ._app import app, Arg, Opt | from ._app import app, Arg, Opt | ||||||
| from .. import about | from .. import about | ||||||
| from ..schemas import ProjectConfigSchema, validate | from ..schemas import ProjectConfigSchema, validate | ||||||
| from ..util import run_command | from ..util import ensure_path, run_command | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| CONFIG_FILE = "project.yml" | CONFIG_FILE = "project.yml" | ||||||
| DIRS = ["assets", "configs", "packages", "metrics", "scripts", "notebooks", "training"] | DIRS = ["assets", "configs", "packages", "metrics", "scripts", "notebooks", "training"] | ||||||
| 
 | CACHES = [ | ||||||
|  |     Path.home() / ".torch", | ||||||
|  |     Path.home() / ".caches" / "torch", | ||||||
|  |     os.environ.get("TORCH_HOME"), | ||||||
|  |     Path.home() / ".keras", | ||||||
|  | ] | ||||||
| 
 | 
 | ||||||
| project_cli = typer.Typer(help="Command-line interface for spaCy projects") | project_cli = typer.Typer(help="Command-line interface for spaCy projects") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @project_cli.callback(invoke_without_command=True) | ||||||
|  | def callback(): | ||||||
|  |     # This runs before every project command and ensures DVC is installed | ||||||
|  |     # TODO: check for "dvc" command instead of Python library? | ||||||
|  |     try: | ||||||
|  |         import dvc  # noqa: F401 | ||||||
|  |     except ImportError: | ||||||
|  |         msg.fail( | ||||||
|  |             "spaCy projects require DVC (Data Version Control)", | ||||||
|  |             "You can install the Python package from pip (pip install dvc) or " | ||||||
|  |             "conda (conda install -c conda-forge dvc). For more details, see the " | ||||||
|  |             "documentation: https://dvc.org/doc/install", | ||||||
|  |             exits=1, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @project_cli.command("clone") | @project_cli.command("clone") | ||||||
| def project_clone_cli( | def project_clone_cli( | ||||||
|     # fmt: off |     # fmt: off | ||||||
|  | @ -27,7 +50,50 @@ def project_clone_cli( | ||||||
|     # fmt: on |     # fmt: on | ||||||
| ): | ): | ||||||
|     """Clone a project template from a repository.""" |     """Clone a project template from a repository.""" | ||||||
|     print("Cloning", repo) |     project_clone(name, dest, repo=repo) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def project_clone(name: str, dest: Path, repo: str = about.__projects__) -> None: | ||||||
|  |     dest = ensure_path(dest) | ||||||
|  |     if not dest or not dest.exists() or not dest.is_dir(): | ||||||
|  |         msg.fail("Not a valid directory to clone project", dest, exits=1) | ||||||
|  |     cmd = ["dvc", "get", repo, name, "-o", str(dest)] | ||||||
|  |     msg.info(" ".join(cmd)) | ||||||
|  |     run_command(cmd) | ||||||
|  |     msg.good(f"Cloned project '{name}' from {repo}") | ||||||
|  |     with msg.loading("Setting up directories..."): | ||||||
|  |         for sub_dir in DIRS: | ||||||
|  |             dir_path = dest / sub_dir | ||||||
|  |             if not dir_path.exists(): | ||||||
|  |                 dir_path.mkdir(parents=True) | ||||||
|  |     msg.good(f"Your project is now ready!", dest.resolve()) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @project_cli.command("get-assets") | ||||||
|  | def project_get_assets_cli( | ||||||
|  |     path: Path = Arg(..., help="Path to cloned project", exists=True, file_okay=False) | ||||||
|  | ): | ||||||
|  |     """Use Data Version Control to get the assets for the project.""" | ||||||
|  |     project_get_assets(path) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def project_get_assets(project_path: Path) -> None: | ||||||
|  |     project_path = ensure_path(project_path) | ||||||
|  |     config = load_project_config(project_path) | ||||||
|  |     assets = config.get("assets", {}) | ||||||
|  |     if not assets: | ||||||
|  |         msg.warn(f"No assets specified in {CONFIG_FILE}", exits=0) | ||||||
|  |     msg.info(f"Getting {len(assets)} asset(s)") | ||||||
|  |     variables = config.get("variables", {}) | ||||||
|  |     for asset in assets: | ||||||
|  |         url = asset["url"].format(**variables) | ||||||
|  |         dest = asset["dest"].format(**variables) | ||||||
|  |         dest_path = project_path / dest | ||||||
|  |         check_asset(url) | ||||||
|  |         cmd = ["dvc", "get-url", url, str(dest_path)] | ||||||
|  |         msg.info(" ".join(cmd)) | ||||||
|  |         run_command(cmd) | ||||||
|  |         msg.good(f"Got asset {dest}") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @project_cli.command("run") | @project_cli.command("run") | ||||||
|  | @ -76,14 +142,21 @@ def load_project_config(path: Path) -> Dict[str, Any]: | ||||||
|     return config |     return config | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_dirs(project_dir: Path) -> None: |  | ||||||
|     for subdir in DIRS: |  | ||||||
|         (project_dir / subdir).mkdir(parents=True) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def run_commands(commands: List[str] = tuple(), variables: Dict[str, str] = {}) -> None: | def run_commands(commands: List[str] = tuple(), variables: Dict[str, str] = {}) -> None: | ||||||
|     for command in commands: |     for command in commands: | ||||||
|         # Substitute variables, e.g. "./{NAME}.json" |         # Substitute variables, e.g. "./{NAME}.json" | ||||||
|         command = command.format(**variables) |         command = command.format(**variables) | ||||||
|         msg.info(command) |         msg.info(command) | ||||||
|         run_command(shlex.split(command)) |         run_command(shlex.split(command)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def check_asset(url: str) -> None: | ||||||
|  |     # If the asset URL is a regular GitHub URL it's likely a mistake | ||||||
|  |     # TODO: support loading from GitHub URLs? Automatically convert to raw? | ||||||
|  |     if re.match("(http(s?)):\/\/github.com", url): | ||||||
|  |         msg.warn( | ||||||
|  |             "Downloading from a regular GitHub URL. This will only download " | ||||||
|  |             "the source of the page, not the actual file. If you want to " | ||||||
|  |             "download the raw file, click on 'Download' on the GitHub page " | ||||||
|  |             "and copy the raw.githubusercontent.com URL instead." | ||||||
|  |         ) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user