diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index df98e711f..ce9c2a5d3 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -1,5 +1,5 @@
 from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
-from typing import TYPE_CHECKING, overload
+from typing import TYPE_CHECKING, overload, cast
 import sys
 import shutil
 from pathlib import Path
@@ -221,18 +221,9 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
 
     config (Dict[str, Any]): The loaded config.
     """
-    command_names = [cmd["name"] for cmd in config.get("commands", [])]
-    workflows = config.get("workflows", {})
-    duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
-    if duplicates:
-        err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
-        msg.fail(err, exits=1)
-    for workflow_name, workflow_steps in workflows.items():
-        if workflow_name in command_names:
-            err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
-            msg.fail(err, exits=1)
-        for step in workflow_steps:
-            if step not in command_names:
+
+    def verify_workflow_step(step: str):
+        if step not in command_names:
                 msg.fail(
                     f"Unknown command specified in workflow '{workflow_name}': {step}",
                     f"Workflows can only refer to commands defined in the 'commands' "
@@ -240,6 +231,30 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
                     exits=1,
                 )
 
+    command_names = [cmd["name"] for cmd in config.get("commands", [])]
+    workflows = config.get("workflows", {})
+    duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
+    if duplicates:
+        err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
+        msg.fail(err, exits=1)
+    for workflow_name, workflow_step_or_lists in workflows.items():
+        if workflow_name in command_names:
+            err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
+            msg.fail(err, exits=1)
+        for step_or_list in workflow_step_or_lists:
+            if isinstance(step_or_list, str):
+                verify_workflow_step(step_or_list)
+            else:
+                workflow_list = cast(List[str], step_or_list)
+                if len(workflow_list) < 2:
+                    msg.fail(
+                        f"Invalid multiprocessing group within '{workflow_name}'",
+                        "A multiprocessing group must reference at least two commands.",
+                        exits=1,
+                    )
+                for step in workflow_list:
+                    verify_workflow_step(step)
+
 
 def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
     """Get the hash for a JSON-serializable object.
diff --git a/spacy/schemas.py b/spacy/schemas.py
index 1dfd8ee85..5a9391640 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -458,8 +458,8 @@ class ProjectConfigSchema(BaseModel):
     vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
     env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
     assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
-    workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
-    commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
+    workflows: Dict[StrictStr, List[Union[StrictStr, List[StrictStr]]]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
+    commands: List[ProjectConfigCommand] = Field([], title="Project command shortcuts")
     title: Optional[str] = Field(None, title="Project title")
     spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
     # fmt: on
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index 0fa6f5670..86583b1ab 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -421,6 +421,42 @@ def test_project_config_interpolation(int_value):
         substitute_project_variables(project)
 
 
+def test_project_config_multiprocessing_good_case():
+    project = {
+        "commands": [
+            {"name": "command1", "script": ["echo", "command1"]},
+            {"name": "command2", "script": ["echo", "command2"]},
+            {"name": "command3", "script": ["echo", "command3"]},
+        ],
+        "workflows": {"all": ["command1", ["command2", "command3"]]},
+    }
+    with make_tempdir() as d:
+        srsly.write_yaml(d / "project.yml", project)
+        load_project_config(d)
+
+
+@pytest.mark.parametrize(
+    "workflows",
+    [
+        {"all": ["command1", ["command2"], "command3"]},
+        {"all": ["command1", ["command2", "command4"]]},
+    ],
+)
+def test_project_config_multiprocessing_bad_case(workflows):
+    project = {
+        "commands": [
+            {"name": "command1", "script": ["echo", "command1"]},
+            {"name": "command2", "script": ["echo", "command2"]},
+            {"name": "command3", "script": ["echo", "command3"]},
+        ],
+        "workflows": workflows,
+    }
+    with make_tempdir() as d:
+        srsly.write_yaml(d / "project.yml", project)
+        with pytest.raises(SystemExit):
+            load_project_config(d)
+
+
 @pytest.mark.parametrize(
     "greeting",
     [342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],