mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Permit multiprocessing groups in YAML
This commit is contained in:
parent
e626df959f
commit
8d08a68174
|
@ -1,5 +1,5 @@
|
|||
from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
|
||||
from typing import TYPE_CHECKING, overload
|
||||
from typing import TYPE_CHECKING, overload, cast
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
@ -151,6 +151,7 @@ def load_project_config(
|
|||
config = srsly.read_yaml(config_path)
|
||||
except ValueError as e:
|
||||
msg.fail(invalid_err, e, exits=1)
|
||||
print(config)
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
if errors:
|
||||
msg.fail(invalid_err)
|
||||
|
@ -221,18 +222,9 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
|
|||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||
workflows = config.get("workflows", {})
|
||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||
if duplicates:
|
||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||
msg.fail(err, exits=1)
|
||||
for workflow_name, workflow_steps in workflows.items():
|
||||
if workflow_name in command_names:
|
||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||
msg.fail(err, exits=1)
|
||||
for step in workflow_steps:
|
||||
if step not in command_names:
|
||||
|
||||
def verify_workflow_step(step: str):
|
||||
if step not in command_names:
|
||||
msg.fail(
|
||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
||||
f"Workflows can only refer to commands defined in the 'commands' "
|
||||
|
@ -240,6 +232,30 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
|
|||
exits=1,
|
||||
)
|
||||
|
||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||
workflows = config.get("workflows", {})
|
||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||
if duplicates:
|
||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||
msg.fail(err, exits=1)
|
||||
for workflow_name, workflow_step_or_lists in workflows.items():
|
||||
if workflow_name in command_names:
|
||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||
msg.fail(err, exits=1)
|
||||
for step_or_list in workflow_step_or_lists:
|
||||
if isinstance(step_or_list, str):
|
||||
verify_workflow_step(step_or_list)
|
||||
else:
|
||||
workflow_list = cast(List[str], step_or_list)
|
||||
if len(workflow_list) < 2:
|
||||
msg.fail(
|
||||
f"Invalid multiprocessing group within '{workflow_name}'",
|
||||
f"A multiprocessing group must reference at least two commands.",
|
||||
exits=1,
|
||||
)
|
||||
for step in workflow_list:
|
||||
verify_workflow_step(step)
|
||||
|
||||
|
||||
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
|
||||
"""Get the hash for a JSON-serializable object.
|
||||
|
|
|
@ -458,8 +458,8 @@ class ProjectConfigSchema(BaseModel):
|
|||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
||||
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
||||
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
||||
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
||||
workflows: Dict[StrictStr, List[Union[StrictStr, List[StrictStr]]]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortcuts")
|
||||
title: Optional[str] = Field(None, title="Project title")
|
||||
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
||||
# fmt: on
|
||||
|
|
|
@ -421,6 +421,42 @@ def test_project_config_interpolation(int_value):
|
|||
substitute_project_variables(project)
|
||||
|
||||
|
||||
def test_project_config_multiprocessing_good_case():
    """Load a project whose workflow mixes a plain step with a nested group.

    The "all" workflow lists one command by name followed by a nested list of
    two further commands, all of which are declared under "commands". The
    config is written to a temporary project.yml and loaded; the test passes
    as long as loading does not exit.
    """
    commands = [
        {"name": "command1", "script": ["echo", "command1"]},
        {"name": "command2", "script": ["echo", "command2"]},
        {"name": "command3", "script": ["echo", "command3"]},
    ]
    # Nested list = group of commands inside the workflow.
    workflows = {"all": ["command1", ["command2", "command3"]]}
    config = {"commands": commands, "workflows": workflows}
    with make_tempdir() as tmp_dir:
        srsly.write_yaml(tmp_dir / "project.yml", config)
        load_project_config(tmp_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "workflows",
    [
        # Nested group containing only a single command.
        {"all": ["command1", ["command2"], "command3"]},
        # Nested group naming "command4", which is not declared below.
        {"all": ["command1", ["command2", "command4"]]},
    ],
)
def test_project_config_multiprocessing_bad_case(workflows):
    """Loading a project with an invalid nested workflow group must exit.

    workflows: the parametrized "workflows" mapping placed into the project
        config alongside three declared commands (command1..command3).

    The config is written to a temporary project.yml; loading it is expected
    to raise SystemExit.
    """
    commands = [
        {"name": "command1", "script": ["echo", "command1"]},
        {"name": "command2", "script": ["echo", "command2"]},
        {"name": "command3", "script": ["echo", "command3"]},
    ]
    config = {"commands": commands, "workflows": workflows}
    with make_tempdir() as tmp_dir:
        srsly.write_yaml(tmp_dir / "project.yml", config)
        with pytest.raises(SystemExit):
            load_project_config(tmp_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"greeting",
|
||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
||||
|
|
Loading…
Reference in New Issue
Block a user