mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Changes after internal discussions
This commit is contained in:
parent
2eb13f2656
commit
9e665f9ad2
|
@ -157,6 +157,7 @@ def load_project_config(
|
|||
print("\n".join(errors))
|
||||
sys.exit(1)
|
||||
validate_project_version(config)
|
||||
validate_max_parallel_processes(config)
|
||||
validate_project_commands(config)
|
||||
# Make sure directories defined in config exist
|
||||
for subdir in config.get("directories", []):
|
||||
|
@ -199,7 +200,7 @@ def substitute_project_variables(
|
|||
|
||||
|
||||
def validate_project_version(config: Dict[str, Any]) -> None:
|
||||
"""If the project defines a compatible spaCy version range, chec that it's
|
||||
"""If the project defines a compatible spaCy version range, check that it's
|
||||
compatible with the current version of spaCy.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
|
@ -215,6 +216,21 @@ def validate_project_version(config: Dict[str, Any]) -> None:
|
|||
msg.fail(err, exits=1)
|
||||
|
||||
|
||||
def validate_max_parallel_processes(config: Dict[str, Any]) -> None:
    """If the project defines a maximum number of parallel processes, check that
    the value is within the permitted range.

    A value below 2 is meaningless (a "parallel" group of one process is just
    sequential execution), so anything less than 2 aborts the CLI run.

    config (Dict[str, Any]): The loaded config.
    """
    # .get() already defaults to None; the explicit second argument was redundant.
    max_parallel_processes = config.get("max_parallel_processes")
    if max_parallel_processes is not None and max_parallel_processes < 2:
        # Message text is unchanged; the split point was moved so that the
        # second fragment no longer is an f-string without placeholders (F541).
        err = (
            f"The {PROJECT_FILE} specifies a value for max_parallel_processes "
            f"({max_parallel_processes}) that is less than 2."
        )
        msg.fail(err, exits=1)
|
||||
|
||||
|
||||
def verify_workflow_step(workflow_name: str, commands: List[str], step: str) -> None:
|
||||
if step not in commands:
|
||||
msg.fail(
|
||||
|
@ -246,18 +262,20 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
|
|||
if isinstance(workflow_item, str):
|
||||
verify_workflow_step(workflow_name, commands, workflow_item)
|
||||
else:
|
||||
assert isinstance(workflow_item, list)
|
||||
assert isinstance(workflow_item[0], str)
|
||||
steps = cast(List[str], workflow_item)
|
||||
assert isinstance(workflow_item, dict)
|
||||
assert len(workflow_item) == 1
|
||||
steps_list = workflow_item["parallel"]
|
||||
assert isinstance(steps_list[0], str)
|
||||
steps = cast(List[str], steps_list)
|
||||
if len(steps) < 2:
|
||||
msg.fail(
|
||||
f"Invalid multiprocessing group within '{workflow_name}'.",
|
||||
f"A multiprocessing group must reference at least two commands.",
|
||||
f"Invalid parallel group within '{workflow_name}'.",
|
||||
f"A parallel group must reference at least two commands.",
|
||||
exits=1,
|
||||
)
|
||||
if len(steps) != len(set(steps)):
|
||||
msg.fail(
|
||||
f"A multiprocessing group within '{workflow_name}' contains a command more than once.",
|
||||
f"A parallel group within '{workflow_name}' contains a command more than once.",
|
||||
f"This is not permitted because it is then not possible to determine when to rerun.",
|
||||
exits=1,
|
||||
)
|
||||
|
@ -580,15 +598,3 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
|
|||
local_msg.info("Using CPU")
|
||||
if has_cupy and gpu_is_available():
|
||||
local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")
|
||||
|
||||
|
||||
def get_workflow_steps(workflow_items: List[Union[str, List[str]]]) -> List[str]:
    """Flatten a workflow's items into a single ordered list of step names.

    workflow_items (List[Union[str, List[str]]]): plain step names and/or
        nested lists of step names (parallel groups).
    RETURNS (List[str]): all step names in order, with nested groups expanded
        in place.
    """
    flattened: List[str] = []
    for item in workflow_items:
        if isinstance(item, str):
            flattened.append(item)
            continue
        # Anything that is not a bare step name must be a non-empty list of
        # step names; these checks mirror the original invariants.
        assert isinstance(item, list)
        assert isinstance(item[0], str)
        flattened.extend(item)
    return flattened
|
||||
|
|
|
@ -77,7 +77,9 @@ def project_document(
|
|||
rendered_steps.append(md.code(step))
|
||||
else:
|
||||
rendered_steps.append(
|
||||
"[" + ", ".join(md.code(p_step) for p_step in step) + "]"
|
||||
"["
|
||||
+ ", ".join(md.code(p_step) for p_step in step["parallel"])
|
||||
+ "]"
|
||||
)
|
||||
data.append([md.code(n), " → ".join(rendered_steps)])
|
||||
if data:
|
||||
|
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
from wasabi import msg
|
||||
|
||||
from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli
|
||||
from .._util import Arg, Opt, NAME, COMMAND, get_workflow_steps
|
||||
from .._util import Arg, Opt, NAME, COMMAND
|
||||
from ...util import working_dir, split_command, join_command, run_command
|
||||
from ...util import SimpleFrozenList
|
||||
|
||||
|
@ -106,7 +106,12 @@ def update_dvc_config(
|
|||
dvc_commands = []
|
||||
config_commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
|
||||
processed_step = False
|
||||
for name in get_workflow_steps(workflows[workflow]):
|
||||
for name in workflows[workflow]:
|
||||
if isinstance(name, dict) and "parallel" in name:
|
||||
msg.fail(
|
||||
f"A DVC workflow may not contain parallel groups",
|
||||
exits=1,
|
||||
)
|
||||
command = config_commands[name]
|
||||
deps = command.get("deps", [])
|
||||
outputs = command.get("outputs", [])
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, List, Dict, Sequence, Any, Iterable, Union, Tuple
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process, Lock
|
||||
from multiprocessing import Process, Lock, Queue
|
||||
from multiprocessing.synchronize import Lock as Lock_t
|
||||
from wasabi import msg
|
||||
from wasabi.util import locale_escape
|
||||
|
@ -15,7 +15,6 @@ from ...util import SimpleFrozenList, is_minor_version_match, ENV_VARS
|
|||
from ...util import check_bool_env_var, SimpleFrozenDict
|
||||
from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
|
||||
from .._util import get_checksum, project_cli, Arg, Opt, COMMAND, parse_config_overrides
|
||||
from .._util import get_workflow_steps
|
||||
|
||||
|
||||
@project_cli.command(
|
||||
|
@ -54,6 +53,7 @@ def project_run(
|
|||
dry: bool = False,
|
||||
capture: bool = False,
|
||||
mult_group_mutex: Optional[Lock_t] = None,
|
||||
completion_queue: Optional[Queue] = None,
|
||||
) -> None:
|
||||
"""Run a named script defined in the project.yml. If the script is part
|
||||
of the default pipeline (defined in the "run" section), DVC is used to
|
||||
|
@ -76,6 +76,7 @@ def project_run(
|
|||
config = load_project_config(project_dir, overrides=overrides)
|
||||
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
|
||||
workflows = config.get("workflows", {})
|
||||
max_parallel_processes = config.get("max_parallel_processes")
|
||||
validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
|
||||
if subcommand in workflows:
|
||||
msg.info(f"Running workflow '{subcommand}'")
|
||||
|
@ -91,8 +92,11 @@ def project_run(
|
|||
mult_group_mutex=mult_group_mutex,
|
||||
)
|
||||
else:
|
||||
assert isinstance(workflow_item, list)
|
||||
assert isinstance(workflow_item[0], str)
|
||||
assert isinstance(workflow_item, dict)
|
||||
assert len(workflow_item) == 1
|
||||
steps_list = workflow_item["parallel"]
|
||||
assert isinstance(steps_list[0], str)
|
||||
completion_queue = Queue(len(steps_list))
|
||||
processes = [
|
||||
Process(
|
||||
target=project_run,
|
||||
|
@ -103,14 +107,26 @@ def project_run(
|
|||
"dry": dry,
|
||||
"capture": capture,
|
||||
"mult_group_mutex": mult_group_mutex,
|
||||
"completion_queue": completion_queue,
|
||||
},
|
||||
)
|
||||
for cmd in workflow_item
|
||||
for cmd in steps_list
|
||||
]
|
||||
for process in processes:
|
||||
process.start()
|
||||
for process in processes:
|
||||
process.join()
|
||||
num_processes = len(processes)
|
||||
if (
|
||||
max_parallel_processes is not None
|
||||
and max_parallel_processes < num_processes
|
||||
):
|
||||
num_processes = max_parallel_processes
|
||||
process_iterator = iter(processes)
|
||||
for _ in range(num_processes):
|
||||
next(process_iterator).start()
|
||||
for _ in range(len(steps_list)):
|
||||
completion_queue.get()
|
||||
next_process = next(process_iterator, None)
|
||||
if next_process is not None:
|
||||
next_process.start()
|
||||
|
||||
else:
|
||||
cmd = commands[subcommand]
|
||||
for dep in cmd.get("deps", []):
|
||||
|
@ -134,6 +150,8 @@ def project_run(
|
|||
run_commands(cmd["script"], dry=dry, capture=capture)
|
||||
if not dry:
|
||||
update_lockfile(current_dir, cmd, mult_group_mutex=mult_group_mutex)
|
||||
if completion_queue is not None:
|
||||
completion_queue.put(None)
|
||||
|
||||
|
||||
def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
|
||||
|
@ -157,12 +175,36 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
|
|||
if help_text:
|
||||
print(f"\n{help_text}\n")
|
||||
elif subcommand in workflows:
|
||||
steps = get_workflow_steps(workflows[subcommand])
|
||||
steps: List[Tuple[str, str]] = []
|
||||
contains_parallel = False
|
||||
for workflow_item in workflows[subcommand]:
|
||||
if isinstance(workflow_item, str):
|
||||
steps.append((" ", workflow_item))
|
||||
else:
|
||||
contains_parallel = True
|
||||
assert isinstance(workflow_item, dict)
|
||||
assert len(workflow_item) == 1
|
||||
steps_list = workflow_item["parallel"]
|
||||
assert isinstance(steps_list[0], str)
|
||||
for i, step in enumerate(steps_list):
|
||||
if i == 0:
|
||||
parallel_char = "╔"
|
||||
elif i + 1 == len(steps_list):
|
||||
parallel_char = "╚"
|
||||
else:
|
||||
parallel_char = "║"
|
||||
steps.append((parallel_char, step))
|
||||
print(f"\nWorkflow consisting of {len(steps)} commands:")
|
||||
steps_data = [
|
||||
(f"{i + 1}. {step}", commands[step].get("help", ""))
|
||||
for i, step in enumerate(steps)
|
||||
]
|
||||
if contains_parallel:
|
||||
steps_data = [
|
||||
(f"{i + 1}. {step[0]} {step[1]}", commands[step[1]].get("help", ""))
|
||||
for i, step in enumerate(steps)
|
||||
]
|
||||
else:
|
||||
steps_data = [
|
||||
(f"{i + 1}. {step[1]}", commands[step[1]].get("help", ""))
|
||||
for i, step in enumerate(steps)
|
||||
]
|
||||
msg.table(steps_data)
|
||||
help_cmd = f"{COMMAND} project run [COMMAND] {project_loc} --help"
|
||||
print(f"For command details, run: {help_cmd}")
|
||||
|
@ -180,9 +222,16 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
|
|||
print(f"Usage: {COMMAND} project run [WORKFLOW] {project_loc}")
|
||||
table_entries: List[Tuple[str, str]] = []
|
||||
for name, workflow_items in workflows.items():
|
||||
table_entries.append(
|
||||
(name, " -> ".join(get_workflow_steps(workflow_items)))
|
||||
)
|
||||
descriptions: List[str] = []
|
||||
for workflow_item in workflow_items:
|
||||
if isinstance(workflow_item, str):
|
||||
descriptions.append(workflow_item)
|
||||
else:
|
||||
assert isinstance(workflow_item, dict)
|
||||
assert len(workflow_item) == 1
|
||||
steps_list = workflow_item["parallel"]
|
||||
descriptions.append("parallel[" + ", ".join(steps_list) + "]")
|
||||
table_entries.append((name, " -> ".join(descriptions)))
|
||||
msg.table(table_entries)
|
||||
|
||||
|
||||
|
|
|
@ -458,10 +458,11 @@ class ProjectConfigSchema(BaseModel):
|
|||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
||||
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
||||
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
||||
workflows: Dict[StrictStr, List[Union[StrictStr, List[StrictStr]]]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||
workflows: Dict[StrictStr, List[Union[StrictStr, Dict[Literal["parallel"], List[StrictStr]]]]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortcuts")
|
||||
title: Optional[str] = Field(None, title="Project title")
|
||||
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
||||
max_parallel_processes: Optional[int] = Field(None, title="Maximum number of permitted parallel processes")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
|
|
|
@ -429,7 +429,7 @@ def test_project_config_multiprocessing_good_case():
|
|||
{"name": "command2", "script": ["echo", "command2"]},
|
||||
{"name": "command3", "script": ["echo", "command3"]},
|
||||
],
|
||||
"workflows": {"all": ["command1", ["command2", "command3"]]},
|
||||
"workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]},
|
||||
}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
|
@ -439,12 +439,13 @@ def test_project_config_multiprocessing_good_case():
|
|||
@pytest.mark.parametrize(
|
||||
"workflows",
|
||||
[
|
||||
{"all": ["command1", ["command2"], "command3"]},
|
||||
{"all": ["command1", ["command2", "command4"]]},
|
||||
{"all": ["command1", ["command2", "command2"]]},
|
||||
{"all": ["command1", {"parallel": ["command2"]}, "command3"]},
|
||||
{"all": ["command1", {"parallel": ["command2", "command4"]}]},
|
||||
{"all": ["command1", {"parallel": ["command2", "command2"]}]},
|
||||
{"all": ["command1", {"serial": ["command2", "command3"]}]},
|
||||
],
|
||||
)
|
||||
def test_project_config_multiprocessing_bad_case(workflows):
|
||||
def test_project_config_multiprocessing_bad_case_workflows(workflows):
|
||||
project = {
|
||||
"commands": [
|
||||
{"name": "command1", "script": ["echo", "command1"]},
|
||||
|
@ -459,6 +460,33 @@ def test_project_config_multiprocessing_bad_case(workflows):
|
|||
load_project_config(d)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
    # Any max_parallel_processes value below 2 must make loading the project
    # config abort; the loader signals the failure by raising SystemExit.
    with make_tempdir() as d:
        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
            ],
            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]},
        }
        # NOTE(review): this inner make_tempdir() shadows the outer `d`, so
        # project.yml is written to a different directory than the one the
        # command scripts above reference. Harmless here because loading is
        # expected to exit before any command runs, but presumably
        # unintentional — confirm whether one tempdir was meant.
        with make_tempdir() as d:
            srsly.write_yaml(d / "project.yml", project)
            with pytest.raises(SystemExit):
                load_project_config(d)
|
||||
|
||||
|
||||
def test_project_run_multiprocessing_good_case():
|
||||
with make_tempdir() as d:
|
||||
|
||||
|
@ -502,7 +530,12 @@ else: # should never happen because of skipping
|
|||
"outputs": [os.sep.join((str(d), "e"))],
|
||||
},
|
||||
],
|
||||
"workflows": {"all": [["commandA", "commandB"], ["commandA", "commandB"]]},
|
||||
"workflows": {
|
||||
"all": [
|
||||
{"parallel": ["commandA", "commandB"]},
|
||||
{"parallel": ["commandB", "commandA"]},
|
||||
]
|
||||
},
|
||||
}
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
load_project_config(d)
|
||||
|
@ -513,6 +546,36 @@ else: # should never happen because of skipping
|
|||
assert not os.path.exists(os.sep.join((str(d), "f")))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4])
def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
    # max_parallel_processes values >= 2 are valid, including values smaller
    # than the parallel group size (2 < 3 commands) and equal/larger (3, 4):
    # the run should throttle or not as needed and still execute every step.
    with make_tempdir() as d:

        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
            ],
            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]},
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        # Each command touches one marker file; all three must exist, proving
        # every step in the parallel group ran despite any throttling.
        assert os.path.exists(os.sep.join((str(d), "A")))
        assert os.path.exists(os.sep.join((str(d), "B")))
        assert os.path.exists(os.sep.join((str(d), "C")))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"greeting",
|
||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
||||
|
|
|
@ -1455,9 +1455,8 @@ Auto-generate [Data Version Control](https://dvc.org) (DVC) config file. Calls
|
|||
the hood to generate the `dvc.yaml`. A DVC project can only define one pipeline,
|
||||
so you need to specify one workflow defined in the
|
||||
[`project.yml`](/usage/projects#project-yml). If no workflow is specified, the
|
||||
first defined workflow is used. Note that any multiprocessing groups in the spaCy
|
||||
config file will be flattened out and defined for sequential execution in the DVC config file
|
||||
as DVC does not support multiprocessing in the same way as spaCy. The DVC config will only be updated
|
||||
first defined workflow is used. Note that the spaCy config file may not contain parallel groups,
|
||||
as DVC does not support parallel execution in the same way as spaCy. The DVC config will only be updated
|
||||
if the `project.yml` changed. For details, see the
|
||||
[DVC integration](/usage/projects#dvc) docs.
|
||||
|
||||
|
|
|
@ -153,9 +153,7 @@ script).
|
|||
> all:
|
||||
> - preprocess
|
||||
> - train
|
||||
> -
|
||||
> - multiprocessingGroupCommand1
|
||||
> - multiprocessingGroupCommand2
|
||||
> - parallel: [parallelCommand1, parallelCommand2]
|
||||
> - package
|
||||
> ```
|
||||
|
||||
|
@ -172,8 +170,8 @@ $ python -m spacy project run all
|
|||
```
|
||||
|
||||
Sometimes it makes sense to execute two or more commands in parallel. A group
|
||||
of commands executed at once is known as a multiprocessing group; a multiprocessing group
|
||||
is defined by indenting the commands it contains. You are responsible for making sure that no
|
||||
of commands executed in parallel is defined using the `parallel` keyword mapping to
|
||||
the commands specified as a list. You are responsible for making sure that no
|
||||
deadlocks, race conditions or other issues can arise from the parallel execution.
|
||||
|
||||
Using the expected [dependencies and outputs](#deps-outputs) defined in the
|
||||
|
@ -239,7 +237,7 @@ pipelines.
|
|||
| `env` | A dictionary of variables, mapped to the names of environment variables that will be read in when running the project. For example, `${env.name}` will use the value of the environment variable defined as `name`. |
|
||||
| `directories` | An optional list of [directories](#project-files) that should be created in the project for assets, training outputs, metrics etc. spaCy will make sure that these directories always exist. |
|
||||
| `assets` | A list of assets that can be fetched with the [`project assets`](/api/cli#project-assets) command. `url` defines a URL or local path, `dest` is the destination file relative to the project directory, and an optional `checksum` ensures that an error is raised if the file's checksum doesn't match. Instead of `url`, you can also provide a `git` block with the keys `repo`, `branch` and `path`, to download from a Git repo. |
|
||||
| `workflows` | A dictionary of workflow names, mapped to a list of command names, to execute in order. Nested lists represent groups of commands to execute concurrently. Workflows can be run with the [`project run`](/api/cli#project-run) command. |
|
||||
| `workflows` | A dictionary of workflow names, mapped to a list of command names, to execute in order. The `parallel` keyword mapping to a list of command names specifies parallel execution. Workflows can be run with the [`project run`](/api/cli#project-run) command. |
|
||||
| `commands` | A list of named commands. A command can define an optional help message (shown in the CLI when the user adds `--help`) and the `script`, a list of commands to run. The `deps` and `outputs` let you define the created file the command depends on and produces, respectively. This lets spaCy determine whether a command needs to be re-run because its dependencies or outputs changed. Commands can be run as part of a workflow, or separately with the [`project run`](/api/cli#project-run) command. |
|
||||
| `spacy_version` | Optional spaCy version range like `>=3.0.0,<3.1.0` that the project is compatible with. If it's loaded with an incompatible version, an error is raised when the project is loaded. |
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user