mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-13 16:44:56 +03:00
Permit multiprocessing groups in YAML
This commit is contained in:
parent
e626df959f
commit
8d08a68174
|
@ -1,5 +1,5 @@
|
||||||
from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
|
from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
|
||||||
from typing import TYPE_CHECKING, overload
|
from typing import TYPE_CHECKING, overload, cast
|
||||||
import sys
|
import sys
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -151,6 +151,7 @@ def load_project_config(
|
||||||
config = srsly.read_yaml(config_path)
|
config = srsly.read_yaml(config_path)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
msg.fail(invalid_err, e, exits=1)
|
msg.fail(invalid_err, e, exits=1)
|
||||||
|
print(config)
|
||||||
errors = validate(ProjectConfigSchema, config)
|
errors = validate(ProjectConfigSchema, config)
|
||||||
if errors:
|
if errors:
|
||||||
msg.fail(invalid_err)
|
msg.fail(invalid_err)
|
||||||
|
@ -221,17 +222,8 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
|
||||||
|
|
||||||
config (Dict[str, Any]): The loaded config.
|
config (Dict[str, Any]): The loaded config.
|
||||||
"""
|
"""
|
||||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
|
||||||
workflows = config.get("workflows", {})
|
def verify_workflow_step(step: str):
|
||||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
|
||||||
if duplicates:
|
|
||||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
|
||||||
msg.fail(err, exits=1)
|
|
||||||
for workflow_name, workflow_steps in workflows.items():
|
|
||||||
if workflow_name in command_names:
|
|
||||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
|
||||||
msg.fail(err, exits=1)
|
|
||||||
for step in workflow_steps:
|
|
||||||
if step not in command_names:
|
if step not in command_names:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
||||||
|
@ -240,6 +232,30 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||||
|
workflows = config.get("workflows", {})
|
||||||
|
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||||
|
if duplicates:
|
||||||
|
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||||
|
msg.fail(err, exits=1)
|
||||||
|
for workflow_name, workflow_step_or_lists in workflows.items():
|
||||||
|
if workflow_name in command_names:
|
||||||
|
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||||
|
msg.fail(err, exits=1)
|
||||||
|
for step_or_list in workflow_step_or_lists:
|
||||||
|
if isinstance(step_or_list, str):
|
||||||
|
verify_workflow_step(step_or_list)
|
||||||
|
else:
|
||||||
|
workflow_list = cast(List[str], step_or_list)
|
||||||
|
if len(workflow_list) < 2:
|
||||||
|
msg.fail(
|
||||||
|
f"Invalid multiprocessing group within '{workflow_name}'",
|
||||||
|
f"A multiprocessing group must reference at least two commands.",
|
||||||
|
exits=1,
|
||||||
|
)
|
||||||
|
for step in workflow_list:
|
||||||
|
verify_workflow_step(step)
|
||||||
|
|
||||||
|
|
||||||
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
|
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
|
||||||
"""Get the hash for a JSON-serializable object.
|
"""Get the hash for a JSON-serializable object.
|
||||||
|
|
|
@ -458,8 +458,8 @@ class ProjectConfigSchema(BaseModel):
|
||||||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
||||||
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
||||||
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
||||||
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
workflows: Dict[StrictStr, List[Union[StrictStr, List[StrictStr]]]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
commands: List[ProjectConfigCommand] = Field([], title="Project command shortcuts")
|
||||||
title: Optional[str] = Field(None, title="Project title")
|
title: Optional[str] = Field(None, title="Project title")
|
||||||
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -421,6 +421,42 @@ def test_project_config_interpolation(int_value):
|
||||||
substitute_project_variables(project)
|
substitute_project_variables(project)
|
||||||
|
|
||||||
|
|
||||||
|
def test_project_config_multiprocessing_good_case():
|
||||||
|
project = {
|
||||||
|
"commands": [
|
||||||
|
{"name": "command1", "script": ["echo", "command1"]},
|
||||||
|
{"name": "command2", "script": ["echo", "command2"]},
|
||||||
|
{"name": "command3", "script": ["echo", "command3"]},
|
||||||
|
],
|
||||||
|
"workflows": {"all": ["command1", ["command2", "command3"]]},
|
||||||
|
}
|
||||||
|
with make_tempdir() as d:
|
||||||
|
srsly.write_yaml(d / "project.yml", project)
|
||||||
|
load_project_config(d)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"workflows",
|
||||||
|
[
|
||||||
|
{"all": ["command1", ["command2"], "command3"]},
|
||||||
|
{"all": ["command1", ["command2", "command4"]]},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_project_config_multiprocessing_bad_case(workflows):
|
||||||
|
project = {
|
||||||
|
"commands": [
|
||||||
|
{"name": "command1", "script": ["echo", "command1"]},
|
||||||
|
{"name": "command2", "script": ["echo", "command2"]},
|
||||||
|
{"name": "command3", "script": ["echo", "command3"]},
|
||||||
|
],
|
||||||
|
"workflows": workflows,
|
||||||
|
}
|
||||||
|
with make_tempdir() as d:
|
||||||
|
srsly.write_yaml(d / "project.yml", project)
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
load_project_config(d)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"greeting",
|
"greeting",
|
||||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user