Improve error messages

This commit is contained in:
Ines Montani 2020-06-29 16:54:47 +02:00
parent 24664efa23
commit 7c08713baa

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Sequence
import typer
import srsly
from pathlib import Path
@ -372,8 +372,7 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
config_commands = config.get("commands", [])
commands = {cmd["name"]: cmd for cmd in config_commands}
if subcommand:
if subcommand not in commands:
msg.fail(f"Can't find command '{subcommand}' in project config", exits=1)
validate_subcommand(commands.keys(), subcommand)
print(f"Usage: {COMMAND} project run {project_dir} {subcommand}")
help_text = commands[subcommand].get("help")
if help_text:
@ -401,8 +400,7 @@ def project_run(project_dir: Path, subcommand: str, *dvc_args) -> None:
config_commands = config.get("commands", [])
variables = config.get("variables", {})
commands = {cmd["name"]: cmd for cmd in config_commands}
if subcommand not in commands:
msg.fail(f"Can't find command '{subcommand}' in project config", exits=1)
validate_subcommand(commands.keys(), subcommand)
if subcommand in config.get("run", []):
# This is one of the pipeline commands tracked in DVC
dvc_cmd = ["dvc", "repro", subcommand, *dvc_args]
@ -448,10 +446,14 @@ def load_project_config(path: Path) -> Dict[str, Any]:
config_path = path / CONFIG_FILE
if not config_path.exists():
msg.fail("Can't find project config", config_path, exits=1)
config = srsly.read_yaml(config_path)
invalid_err = f"Invalid project config in {CONFIG_FILE}"
try:
config = srsly.read_yaml(config_path)
except ValueError as e:
msg.fail(invalid_err, e, exits=1)
errors = validate(ProjectConfigSchema, config)
if errors:
msg.fail(f"Invalid project config in {CONFIG_FILE}", "\n".join(errors), exits=1)
msg.fail(invalid_err, "\n".join(errors), exits=1)
return config
@ -490,8 +492,7 @@ def update_dvc_config(
# commands in project.yml and should be run in sequence
config_commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
for name in config.get("run", []):
if name not in config_commands:
msg.fail(f"Can't find command '{name}' in project config", exits=1)
validate_subcommand(config_commands.keys(), name)
command = config_commands[name]
deps = command.get("deps", [])
outputs = command.get("outputs", [])
@ -634,6 +635,20 @@ def check_clone(name: str, dest: Path, repo: str) -> None:
)
def validate_subcommand(commands: Sequence[str], subcommand: str) -> None:
"""Check that a subcommand is valid and defined. Raises an error otherwise.
commands (Sequence[str]): The available commands.
subcommand (str): The subcommand.
"""
if subcommand not in commands:
msg.fail(
f"Can't find command '{subcommand}' in {CONFIG_FILE}. "
f"Available commands: {', '.join(commands)}",
exits=1,
)
def download_file(url: str, dest: Path, chunk_size: int = 1024) -> None:
"""Download a file using requests.