diff --git a/spacy/cli/project.py b/spacy/cli/project.py index 5011a13f9..c02c1cf98 100644 --- a/spacy/cli/project.py +++ b/spacy/cli/project.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Sequence import typer import srsly from pathlib import Path @@ -372,8 +372,7 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None: config_commands = config.get("commands", []) commands = {cmd["name"]: cmd for cmd in config_commands} if subcommand: - if subcommand not in commands: - msg.fail(f"Can't find command '{subcommand}' in project config", exits=1) + validate_subcommand(commands.keys(), subcommand) print(f"Usage: {COMMAND} project run {project_dir} {subcommand}") help_text = commands[subcommand].get("help") if help_text: @@ -401,8 +400,7 @@ def project_run(project_dir: Path, subcommand: str, *dvc_args) -> None: config_commands = config.get("commands", []) variables = config.get("variables", {}) commands = {cmd["name"]: cmd for cmd in config_commands} - if subcommand not in commands: - msg.fail(f"Can't find command '{subcommand}' in project config", exits=1) + validate_subcommand(commands.keys(), subcommand) if subcommand in config.get("run", []): # This is one of the pipeline commands tracked in DVC dvc_cmd = ["dvc", "repro", subcommand, *dvc_args] @@ -448,10 +446,14 @@ def load_project_config(path: Path) -> Dict[str, Any]: config_path = path / CONFIG_FILE if not config_path.exists(): msg.fail("Can't find project config", config_path, exits=1) - config = srsly.read_yaml(config_path) + invalid_err = f"Invalid project config in {CONFIG_FILE}" + try: + config = srsly.read_yaml(config_path) + except ValueError as e: + msg.fail(invalid_err, e, exits=1) errors = validate(ProjectConfigSchema, config) if errors: - msg.fail(f"Invalid project config in {CONFIG_FILE}", "\n".join(errors), exits=1) + msg.fail(invalid_err, "\n".join(errors), exits=1) return config @@ -490,8 +492,7 @@ def update_dvc_config( # commands in project.yml and should be run in sequence config_commands = {cmd["name"]: cmd for cmd in config.get("commands", [])} for name in config.get("run", []): - if name not in config_commands: - msg.fail(f"Can't find command '{name}' in project config", exits=1) + validate_subcommand(config_commands.keys(), name) command = config_commands[name] deps = command.get("deps", []) outputs = command.get("outputs", []) @@ -634,6 +635,20 @@ def check_clone(name: str, dest: Path, repo: str) -> None: ) +def validate_subcommand(commands: Sequence[str], subcommand: str) -> None: + """Check that a subcommand is valid and defined. Raises an error otherwise. + + commands (Sequence[str]): The available commands. + subcommand (str): The subcommand. + """ + if subcommand not in commands: + msg.fail( + f"Can't find command '{subcommand}' in {CONFIG_FILE}. " + f"Available commands: {', '.join(commands)}", + exits=1, + ) + + def download_file(url: str, dest: Path, chunk_size: int = 1024) -> None: """Download a file using requests.