diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 12db1e75b..b0e6dd4d6 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -215,6 +215,16 @@ def validate_project_version(config: Dict[str, Any]) -> None: msg.fail(err, exits=1) +def verify_workflow_step(workflow_name: str, commands: List[str], step: str) -> None: + if step not in commands: + msg.fail( + f"Unknown command specified in workflow '{workflow_name}': {step}", + f"Workflows can only refer to commands defined in the 'commands' " + f"section of the {PROJECT_FILE}.", + exits=1, + ) + + def validate_project_commands(config: Dict[str, Any]) -> None: """Check that project commands and workflows are valid, don't contain duplicates, don't clash and only refer to commands that exist. @@ -222,44 +232,37 @@ def validate_project_commands(config: Dict[str, Any]) -> None: config (Dict[str, Any]): The loaded config. """ - def verify_workflow_step(workflow_name: str, step: str) -> None: - if step not in command_names: - msg.fail( - f"Unknown command specified in workflow '{workflow_name}': {step}", - f"Workflows can only refer to commands defined in the 'commands' " - f"section of the {PROJECT_FILE}.", - exits=1, - ) - - command_names = [cmd["name"] for cmd in config.get("commands", [])] + commands = [cmd["name"] for cmd in config.get("commands", [])] workflows = config.get("workflows", {}) - duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1]) + duplicates = set([cmd for cmd in commands if commands.count(cmd) > 1]) if duplicates: err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}" msg.fail(err, exits=1) - for workflow_name, workflow_step_or_lists in workflows.items(): - if workflow_name in command_names: + for workflow_name, workflow_items in workflows.items(): + if workflow_name in commands: err = f"Can't use workflow name '{workflow_name}': name already exists as a command" msg.fail(err, exits=1) - for step_or_list in workflow_step_or_lists: + for step_or_list in workflow_items: if isinstance(step_or_list, str): - verify_workflow_step(workflow_name, step_or_list) + verify_workflow_step(workflow_name, commands, step_or_list) else: - workflow_list = cast(List[str], step_or_list) - if len(workflow_list) < 2: + assert isinstance(step_or_list, list) + assert isinstance(step_or_list[0], str) + steps = cast(List[str], step_or_list) + if len(steps) < 2: msg.fail( f"Invalid multiprocessing group within '{workflow_name}'.", f"A multiprocessing group must reference at least two commands.", exits=1, ) - if len(workflow_list) != len(set(workflow_list)): + if len(steps) != len(set(steps)): msg.fail( f"A multiprocessing group within '{workflow_name}' contains a command more than once.", f"This is not permitted because it is then not possible to determine when to rerun.", exits=1, ) - for step in workflow_list: - verify_workflow_step(workflow_name, step) + for step in steps: + verify_workflow_step(workflow_name, commands, step) def get_hash(data, exclude: Iterable[str] = tuple()) -> str: diff --git a/spacy/cli/project/dvc.py b/spacy/cli/project/dvc.py index b03a95635..3069b01ea 100644 --- a/spacy/cli/project/dvc.py +++ b/spacy/cli/project/dvc.py @@ -110,6 +110,8 @@ def update_dvc_config( if isinstance(cmdOrMultiprocessingGroup, str): names.append(cmdOrMultiprocessingGroup) else: + assert isinstance(cmdOrMultiprocessingGroup, list) + assert isinstance(cmdOrMultiprocessingGroup[0], str) names.extend(cmdOrMultiprocessingGroup) for name in names: command = config_commands[name] diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py index 8f04006b5..8486a002e 100644 --- a/spacy/cli/project/run.py +++ b/spacy/cli/project/run.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Sequence, Any, Iterable, cast +from typing import Optional, List, Dict, Sequence, Any, Iterable, Union, Tuple from pathlib import Path from multiprocessing import Process, Lock from multiprocessing.synchronize import Lock as Lock_t @@ -133,6 +133,18 @@ def project_run( update_lockfile(current_dir, cmd, mult_group_mutex=mult_group_mutex) +def _get_workflow_steps(workflow_items: List[Union[str, List[str]]]) -> List[str]: + steps: List[str] = [] + for workflow_item in workflow_items: + if isinstance(workflow_item, str): + steps.append(workflow_item) + else: + assert isinstance(workflow_item, list) + assert isinstance(workflow_item[0], str) + steps.extend(workflow_item) + return steps + + def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None: """Simulate a CLI help prompt using the info available in the project.yml. @@ -154,12 +166,7 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None: if help_text: print(f"\n{help_text}\n") elif subcommand in workflows: - steps = [] - for cmdOrMultiprocessingGroup in workflows[subcommand]: - if isinstance(cmdOrMultiprocessingGroup, str): - steps.append(cmdOrMultiprocessingGroup) - else: - steps.extend(cmdOrMultiprocessingGroup) + steps = _get_workflow_steps(workflows[subcommand]) print(f"\nWorkflow consisting of {len(steps)} commands:") steps_data = [ (f"{i + 1}. {step}", commands[step].get("help", "")) @@ -180,7 +187,12 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None: if workflows: print(f"Available workflows in {PROJECT_FILE}") print(f"Usage: {COMMAND} project run [WORKFLOW] {project_loc}") - msg.table([(name, " -> ".join(steps)) for name, steps in workflows.items()]) + msg.table( + [ + (name, " -> ".join(_get_workflow_steps(workflow_items))) + for name, workflow_items in workflows.items() + ] + ) def run_commands( diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 310bc97b8..e10486245 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -492,14 +492,14 @@ else: # should never happen because of skipping "script": [ " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d")) ], - "outputs": [" ".join((str(d), "c"))], + "outputs": [os.sep.join((str(d), "f"))], }, { "name": "commandB", "script": [ " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f")) ], - "outputs": [" ".join((str(d), "e"))], + "outputs": [os.sep.join((str(d), "e"))], }, ], "workflows": {"all": [["commandA", "commandB"], ["commandA", "commandB"]]},