Permit multiprocessing groups in YAML

This commit is contained in:
richardpaulhudson 2022-05-09 12:50:25 +02:00
parent e626df959f
commit 8d08a68174
3 changed files with 67 additions and 15 deletions

View File

@ -1,5 +1,5 @@
from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, overload, cast
import sys
import shutil
from pathlib import Path
@ -151,6 +151,7 @@ def load_project_config(
config = srsly.read_yaml(config_path)
except ValueError as e:
msg.fail(invalid_err, e, exits=1)
print(config)
errors = validate(ProjectConfigSchema, config)
if errors:
msg.fail(invalid_err)
@ -221,18 +222,9 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
config (Dict[str, Any]): The loaded config.
"""
command_names = [cmd["name"] for cmd in config.get("commands", [])]
workflows = config.get("workflows", {})
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
if duplicates:
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
msg.fail(err, exits=1)
for workflow_name, workflow_steps in workflows.items():
if workflow_name in command_names:
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
msg.fail(err, exits=1)
for step in workflow_steps:
if step not in command_names:
def verify_workflow_step(step: str):
if step not in command_names:
msg.fail(
f"Unknown command specified in workflow '{workflow_name}': {step}",
f"Workflows can only refer to commands defined in the 'commands' "
@ -240,6 +232,30 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
exits=1,
)
command_names = [cmd["name"] for cmd in config.get("commands", [])]
workflows = config.get("workflows", {})
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
if duplicates:
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
msg.fail(err, exits=1)
for workflow_name, workflow_step_or_lists in workflows.items():
if workflow_name in command_names:
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
msg.fail(err, exits=1)
for step_or_list in workflow_step_or_lists:
if isinstance(step_or_list, str):
verify_workflow_step(step_or_list)
else:
workflow_list = cast(List[str], step_or_list)
if len(workflow_list) < 2:
msg.fail(
f"Invalid multiprocessing group within '{workflow_name}'",
f"A multiprocessing group must reference at least two commands.",
exits=1,
)
for step in workflow_list:
verify_workflow_step(step)
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
"""Get the hash for a JSON-serializable object.

View File

@ -458,8 +458,8 @@ class ProjectConfigSchema(BaseModel):
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
workflows: Dict[StrictStr, List[Union[StrictStr, List[StrictStr]]]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
commands: List[ProjectConfigCommand] = Field([], title="Project command shortcuts")
title: Optional[str] = Field(None, title="Project title")
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
# fmt: on

View File

@ -421,6 +421,42 @@ def test_project_config_interpolation(int_value):
substitute_project_variables(project)
def test_project_config_multiprocessing_good_case():
    """Check that a workflow containing a valid multiprocessing group loads.

    The "all" workflow lists one plain command followed by a nested list of
    two commands. Every referenced name is declared under "commands", so
    loading the project config must complete without raising.
    """
    commands = [
        {"name": "command1", "script": ["echo", "command1"]},
        {"name": "command2", "script": ["echo", "command2"]},
        {"name": "command3", "script": ["echo", "command3"]},
    ]
    # A nested list inside a workflow denotes a multiprocessing group.
    workflows = {"all": ["command1", ["command2", "command3"]]}
    project = {"commands": commands, "workflows": workflows}
    with make_tempdir() as tmp_dir:
        config_path = tmp_dir / "project.yml"
        srsly.write_yaml(config_path, project)
        load_project_config(tmp_dir)
@pytest.mark.parametrize(
    "workflows",
    [
        # A multiprocessing group holding only a single command.
        {"all": ["command1", ["command2"], "command3"]},
        # A multiprocessing group naming a command that is not declared.
        {"all": ["command1", ["command2", "command4"]]},
    ],
)
def test_project_config_multiprocessing_bad_case(workflows):
    """Check that invalid multiprocessing groups abort config loading.

    workflows: The "workflows" section to place in the project file; each
        parametrized value contains one malformed multiprocessing group.

    Loading the written project file is expected to raise SystemExit.
    """
    commands = [
        {"name": "command1", "script": ["echo", "command1"]},
        {"name": "command2", "script": ["echo", "command2"]},
        {"name": "command3", "script": ["echo", "command3"]},
    ]
    project = {"commands": commands, "workflows": workflows}
    with make_tempdir() as tmp_dir:
        config_path = tmp_dir / "project.yml"
        srsly.write_yaml(config_path, project)
        # Only the load call is expected to exit, not the file write above.
        with pytest.raises(SystemExit):
            load_project_config(tmp_dir)
@pytest.mark.parametrize(
"greeting",
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],