More updates based on review comments

richardpaulhudson 2022-07-25 15:33:46 +02:00
parent 6ac15ad507
commit 4eb61a71c8
3 changed files with 185 additions and 129 deletions

View File

@@ -266,11 +266,7 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
if isinstance(workflow_item, str):
verify_workflow_step(workflow_name, commands, workflow_item)
else:
- assert isinstance(workflow_item, dict)
- assert len(workflow_item) == 1
- steps_list = workflow_item["parallel"]
- assert isinstance(steps_list[0], str)
- steps = cast(List[str], steps_list)
+ steps = cast(List[str], workflow_item["parallel"])
if len(steps) < 2:
msg.fail(
f"Invalid parallel group within '{workflow_name}'.",

View File

@@ -1,22 +1,3 @@
- from typing import Any, List, Optional, Dict, Union, cast
- import os
- import sys
- from pathlib import Path
- from time import time
- from multiprocessing import Manager, Queue, Process, get_context
- from multiprocessing.context import SpawnProcess
- from shutil import rmtree
- from signal import SIGTERM
- from subprocess import STDOUT, Popen, TimeoutExpired
- from dataclasses import dataclass, field
- from wasabi import msg
- from wasabi.util import color, supports_ansi
- from .._util import check_rerun, check_deps, update_lockfile, load_project_config
- from ...util import SimpleFrozenDict, working_dir, split_command, join_command
- from ...util import check_bool_env_var, ENV_VARS
- from ...errors import Errors
"""
Permits the execution of a parallel command group.
@@ -30,6 +11,25 @@ of the states is documented alongside the _ParallelCommandInfo.STATES code. Note
between the states 'failed' and 'terminated' is not meaningful on Windows, so that both are displayed
as 'failed/terminated' on Windows systems.
"""
+ from typing import Any, List, Optional, Dict, Union, cast, Iterator
+ import os
+ import sys
+ import queue
+ from pathlib import Path
+ from time import time
+ from multiprocessing import Manager, Queue, get_context
+ from multiprocessing.context import SpawnProcess
+ from shutil import rmtree
+ from signal import SIGTERM
+ from subprocess import STDOUT, Popen, TimeoutExpired
+ from dataclasses import dataclass, field
+ from wasabi import msg
+ from wasabi.util import color, supports_ansi
+ from .._util import check_rerun, check_deps, update_lockfile, load_project_config
+ from ...util import SimpleFrozenDict, working_dir, split_command, join_command
+ from ...util import check_bool_env_var, ENV_VARS
+ from ...errors import Errors
# Use spawn to create worker processes on all OSs for consistency
mp_context = get_context("spawn")
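The spawn start method launches each worker in a fresh Python interpreter instead of forking, so behaviour is consistent across Linux, macOS and Windows. A minimal sketch of how such a context is typically used (illustrative only, not taken from the diff):

    from multiprocessing import get_context

    mp_context = get_context("spawn")

    def work(n: int) -> None:
        print(f"worker {n} running")

    if __name__ == "__main__":  # required with spawn, since the module is re-imported
        proc = mp_context.Process(target=work, args=(1,))
        proc.start()
        proc.join()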
@@ -145,7 +145,6 @@ def project_run_parallel_group(
config = load_project_config(project_dir, overrides=overrides)
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
- parallel_group_status_queue = Manager().Queue()
max_parallel_processes = config.get("max_parallel_processes")
cmd_infos = [
_ParallelCommandInfo(cmd_name, commands[cmd_name], cmd_index)
@@ -153,6 +152,27 @@ def project_run_parallel_group(
]
with working_dir(project_dir) as current_dir:
+ rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True)
+ os.mkdir(PARALLEL_LOGGING_DIR_NAME)
+ divider_parallel_descriptor = parallel_descriptor = (
+ "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]"
+ )
+ if len(divider_parallel_descriptor) > MAX_WIDTH_DIVIDER:
+ divider_parallel_descriptor = (
+ divider_parallel_descriptor[: (MAX_WIDTH_DIVIDER - 3)] + "..."
+ )
+ msg.divider(divider_parallel_descriptor)
+ if not DISPLAY_STATUS_TABLE and len(parallel_descriptor) > MAX_WIDTH_DIVIDER:
+ # reprint the descriptor if it was too long and had to be cut short
+ print(parallel_descriptor)
+ msg.info(
+ "Temporary logs are being written to "
+ + os.sep.join((str(current_dir), PARALLEL_LOGGING_DIR_NAME))
+ )
+ parallel_group_status_queue = Manager().Queue()
for cmd_info in cmd_infos:
check_deps(cmd_info.cmd, cmd_info.cmd_name, project_dir, dry)
if (
@@ -162,8 +182,6 @@ def project_run_parallel_group(
and not force
):
cmd_info.change_state("not rerun")
- rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True)
- os.mkdir(PARALLEL_LOGGING_DIR_NAME)
worker_processes: List[SpawnProcess] = []
proc_to_cmd_infos: Dict[SpawnProcess, _ParallelCommandInfo] = {}
num_concurr_worker_processes = 0
@@ -181,110 +199,25 @@ def project_run_parallel_group(
)
worker_processes.append(worker_process)
proc_to_cmd_infos[worker_process] = cmd_info
num_concurr_worker_processes = len(worker_processes)
if (
max_parallel_processes is not None
and max_parallel_processes < num_concurr_worker_processes
):
num_concurr_worker_processes = max_parallel_processes
worker_process_iterator = iter(worker_processes)
for _ in range(num_concurr_worker_processes):
_start_worker_process(next(worker_process_iterator), proc_to_cmd_infos)
- divider_parallel_descriptor = parallel_descriptor = (
- "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]"
+ _process_worker_status_messages(
+ cmd_infos,
+ proc_to_cmd_infos,
+ parallel_group_status_queue,
+ worker_process_iterator,
+ current_dir,
+ dry,
)
- if len(divider_parallel_descriptor) > MAX_WIDTH_DIVIDER:
- divider_parallel_descriptor = (
- divider_parallel_descriptor[: (MAX_WIDTH_DIVIDER - 3)] + "..."
- )
- msg.divider(divider_parallel_descriptor)
- if not DISPLAY_STATUS_TABLE and len(parallel_descriptor) > MAX_WIDTH_DIVIDER:
- # reprint the descriptor if it was too long and had to be cut short
- print(parallel_descriptor)
- msg.info(
- "Temporary logs are being written to "
- + os.sep.join((str(current_dir), PARALLEL_LOGGING_DIR_NAME))
- )
- status_table_not_yet_displayed = True
- while any(
- cmd_info.state.name in ("pending", "starting", "running")
- for cmd_info in cmd_infos
- ):
- try:
- mess: Dict[str, Union[str, int]] = parallel_group_status_queue.get(
- timeout=PARALLEL_GROUP_STATUS_INTERVAL * 20
- )
- except Exception:
- # No more messages are being received: the whole group has hung
- for other_cmd_info in (
- c for c in cmd_infos if c.state.name in ("starting", "running")
- ):
- other_cmd_info.change_state("hung")
- for other_cmd_info in (
- c for c in cmd_infos if c.state.name == "pending"
- ):
- other_cmd_info.change_state("cancelled")
- break
- cmd_info = cmd_infos[cast(int, mess["cmd_index"])]
- if mess["type"] in ("started", "keepalive"):
- cmd_info.last_keepalive_time = int(time())
- for other_cmd_info in (
- c for c in cmd_infos if c.state.name in ("starting", "running")
- ):
- if (
- other_cmd_info.last_keepalive_time is not None
- and time() - other_cmd_info.last_keepalive_time
- > PARALLEL_GROUP_STATUS_INTERVAL * 20
- ):
- # a specific command has hung
- other_cmd_info.change_state("hung")
- if mess["type"] == "started":
- cmd_info.change_state("running")
- cmd_info.running_os_cmd_index = cast(int, mess["os_cmd_index"])
- cmd_info.pid = cast(int, mess["pid"])
- if mess["type"] == "completed":
- cmd_info.rc = cast(int, mess["rc"])
- if cmd_info.rc == 0:
- cmd_info.change_state("succeeded")
- if not dry:
- update_lockfile(current_dir, cmd_info.cmd)
- next_worker_process = next(worker_process_iterator, None)
- if next_worker_process is not None:
- _start_worker_process(next_worker_process, proc_to_cmd_infos)
- elif cmd_info.rc > 0:
- cmd_info.change_state("failed")
- else:
- cmd_info.change_state("terminated")
- cmd_info.console_output = cast(str, mess["console_output"])
- if any(
- c for c in cmd_infos if c.state.name in ("failed", "terminated", "hung")
- ):
- # a command in the group hasn't succeeded, so terminate/cancel the rest
- for other_cmd_info in (
- c for c in cmd_infos if c.state.name == "running"
- ):
- try:
- os.kill(cast(int, other_cmd_info.pid), SIGTERM)
- except:
- # the subprocess the main process is trying to kill could already
- # have completed, and the message from the worker process notifying
- # the main process about this could still be in the queue
- pass
- for other_cmd_info in (
- c for c in cmd_infos if c.state.name == "pending"
- ):
- other_cmd_info.change_state("cancelled")
- if mess["type"] != "keepalive" and DISPLAY_STATUS_TABLE:
- if status_table_not_yet_displayed:
- status_table_not_yet_displayed = False
- else:
- # overwrite the existing status table
- print("\033[2K\033[F" * (4 + len(cmd_infos)))
- data = [[c.cmd_name, c.state_repr] for c in cmd_infos]
- header = ["Command", "Status"]
- msg.table(data, header=header)
for cmd_info in (c for c in cmd_infos if c.state.name != "cancelled"):
msg.divider(cmd_info.cmd_name)
if cmd_info.state.name == "not rerun":
@@ -305,6 +238,98 @@ def project_run_parallel_group(
sys.exit(-1)
+ def _process_worker_status_messages(
+ cmd_infos: List[_ParallelCommandInfo],
+ proc_to_cmd_infos: Dict[SpawnProcess, _ParallelCommandInfo],
+ parallel_group_status_queue: queue.Queue,
+ worker_process_iterator: Iterator[SpawnProcess],
+ current_dir: Path,
+ dry: bool,
+ ) -> None:
+ """Listens on the status queue and processes messages received from the worker processes.
+ cmd_infos: a list of info objects about the commands in the parallel group
+ proc_to_cmd_infos: a dictionary from Process objects to command info objects
+ parallel_group_status_queue: the status queue
+ worker_process_iterator: an iterator over the processes, some or all of which
+ will already have been iterated over and started
+ current_dir: the current directory
+ dry (bool): Perform a dry run and don't execute commands.
+ """
+ status_table_not_yet_displayed = True
+ while any(
+ cmd_info.state.name in ("pending", "starting", "running")
+ for cmd_info in cmd_infos
+ ):
+ try:
+ mess: Dict[str, Union[str, int]] = parallel_group_status_queue.get(
+ timeout=PARALLEL_GROUP_STATUS_INTERVAL * 20
+ )
+ except Exception:
+ # No more messages are being received: the whole group has hung
+ for other_cmd_info in (
+ c for c in cmd_infos if c.state.name in ("starting", "running")
+ ):
+ other_cmd_info.change_state("hung")
+ for other_cmd_info in (c for c in cmd_infos if c.state.name == "pending"):
+ other_cmd_info.change_state("cancelled")
+ break
+ cmd_info = cmd_infos[cast(int, mess["cmd_index"])]
+ if mess["mess_type"] in ("started", "keepalive"):
+ cmd_info.last_keepalive_time = int(time())
+ for other_cmd_info in (
+ c for c in cmd_infos if c.state.name in ("starting", "running")
+ ):
+ if (
+ other_cmd_info.last_keepalive_time is not None
+ and time() - other_cmd_info.last_keepalive_time
+ > PARALLEL_GROUP_STATUS_INTERVAL * 20
+ ):
+ # a specific command has hung
+ other_cmd_info.change_state("hung")
+ if mess["mess_type"] == "started":
+ cmd_info.change_state("running")
+ cmd_info.running_os_cmd_index = cast(int, mess["os_cmd_index"])
+ cmd_info.pid = cast(int, mess["pid"])
+ if mess["mess_type"] == "completed":
+ cmd_info.rc = cast(int, mess["rc"])
+ if cmd_info.rc == 0:
+ cmd_info.change_state("succeeded")
+ if not dry:
+ update_lockfile(current_dir, cmd_info.cmd)
+ next_worker_process = next(worker_process_iterator, None)
+ if next_worker_process is not None:
+ _start_worker_process(next_worker_process, proc_to_cmd_infos)
+ elif cmd_info.rc > 0:
+ cmd_info.change_state("failed")
+ else:
+ cmd_info.change_state("terminated")
+ cmd_info.console_output = cast(str, mess["console_output"])
+ if any(
+ c for c in cmd_infos if c.state.name in ("failed", "terminated", "hung")
+ ):
+ # a command in the group hasn't succeeded, so terminate/cancel the rest
+ for other_cmd_info in (c for c in cmd_infos if c.state.name == "running"):
+ try:
+ os.kill(cast(int, other_cmd_info.pid), SIGTERM)
+ except:
+ # the subprocess the main process is trying to kill could already
+ # have completed, and the message from the worker process notifying
+ # the main process about this could still be in the queue
+ pass
+ for other_cmd_info in (c for c in cmd_infos if c.state.name == "pending"):
+ other_cmd_info.change_state("cancelled")
+ if mess["mess_type"] != "keepalive" and DISPLAY_STATUS_TABLE:
+ if status_table_not_yet_displayed:
+ status_table_not_yet_displayed = False
+ else:
+ # overwrite the existing status table
+ print("\033[2K\033[F" * (4 + len(cmd_infos)))
+ data = [[c.cmd_name, c.state_repr] for c in cmd_infos]
+ header = ["Command", "Status"]
+ msg.table(data, header=header)
def _project_run_parallel_cmd(
cmd_info: _ParallelCommandInfo,
*,
@@ -314,9 +339,9 @@ def _project_run_parallel_cmd(
) -> None:
"""Run a single spaCy projects command as a worker process.
- Communicates with the main process via queue messages that consist of
- dictionaries whose type is determined by the entry 'type'. Possible
- values of 'type' are 'started', 'completed' and 'keepalive'. Each dictionary
+ Communicates with the main process via queue messages whose type is determined
+ by the entry 'mess_type' and that are structured as dictionaries. Possible
+ values of 'mess_type' are 'started', 'completed' and 'keepalive'. Each dictionary
type contains different additional fields."""
# we can use the command name as a unique log filename because a parallel
# group is not allowed to contain the same command more than once
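For orientation, a sketch of the three message shapes the worker processes put on the status queue, using only the fields visible in this diff; the real dictionaries may carry further entries, and the comments are inferred rather than taken from the source:

    started_mess = {
        "mess_type": "started",
        "cmd_index": 0,     # index of the command within the parallel group
        "os_cmd_index": 0,  # index of the OS command within that command's script
        "pid": 12345,       # PID of the spawned subprocess (-1 on a dry run)
    }
    keepalive_mess = {"mess_type": "keepalive", "cmd_index": 0}
    completed_mess = {
        "mess_type": "completed",
        "cmd_index": 0,
        "rc": 0,               # subprocess return code
        "console_output": "",  # contents of the temporary log file
    }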
@@ -331,6 +356,8 @@ def _project_run_parallel_cmd(
if len(command) and command[0] in ("python", "python3"):
# -u: prevent buffering within Python
command = [sys.executable, "-u", *command[1:]]
+ elif len(command) and command[0].startswith("python3"): # e.g. python3.10
+ command = [command[0], "-u", *command[1:]]
elif len(command) and command[0] in ("pip", "pip3"):
command = [sys.executable, "-m", "pip", *command[1:]]
logfile.write(
@@ -364,7 +391,7 @@ def _project_run_parallel_cmd(
break
parallel_group_status_queue.put(
{
- "type": "started",
+ "mess_type": "started",
"cmd_index": cmd_info.cmd_index,
"os_cmd_index": os_cmd_index,
"pid": sp.pid,
@@ -378,7 +405,7 @@ def _project_run_parallel_cmd(
if sp.returncode == None:
parallel_group_status_queue.put(
{
- "type": "keepalive",
+ "mess_type": "keepalive",
"cmd_index": cmd_info.cmd_index,
}
)
@@ -403,7 +430,7 @@ def _project_run_parallel_cmd(
else: # dry run
parallel_group_status_queue.put(
{
- "type": "started",
+ "mess_type": "started",
"cmd_index": cmd_info.cmd_index,
"os_cmd_index": os_cmd_index,
"pid": -1,
@@ -419,7 +446,7 @@ def _project_run_parallel_cmd(
rc = sp.returncode
parallel_group_status_queue.put(
{
- "type": "completed",
+ "mess_type": "completed",
"cmd_index": cmd_info.cmd_index,
"rc": rc,
"console_output": logfile.read(),

View File

@@ -8,6 +8,7 @@ from .util import make_tempdir
def test_project_config_multiprocessing_good_case():
+ """Confirms that a correct parallel command group is loaded from the project file."""
project = {
"commands": [
{"name": "command1", "script": ["echo", "command1"]},
@@ -31,6 +32,14 @@ def test_project_config_multiprocessing_good_case():
],
)
def test_project_config_multiprocessing_bad_case_workflows(workflows):
+ """Confirms that parallel command groups that do not obey various rules
+ result in errors:
+ - a parallel group with only one command
+ - a parallel group with a non-existent command
+ - a parallel group containing the same command more than once
+ - a parallel group whose dictionary name is not 'parallel'
+ """
project = {
"commands": [
{"name": "command1", "script": ["echo", "command1"]},
@@ -47,6 +56,13 @@ def test_project_config_multiprocessing_bad_case_workflows(workflows):
@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
+ """Confirms that invalid max_parallel_processes values are not accepted:
+ - negative value
+ - zero
+ - 1 (because the commands in the group can then be executed serially)
+ """
with make_tempdir() as d:
project = {
"max_parallel_processes": max_parallel_processes,
@@ -73,6 +89,13 @@ def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_proc
def test_project_run_multiprocessing_good_case():
+ """Executes a parallel command group. The script pscript.py writes to a file, then polls until
+ a second file exists. It is executed twice by commands in a parallel group that write
+ and read opposite files, thus proving that the commands are being executed simultaneously. It
+ then checks whether a file f exists, writes to f if it does not already exist, and writes to a
+ second file if f does already exist. Although the workflow is run twice, the fact that both
+ commands should be skipped on the second run means that these second files should never be written.
+ """
with make_tempdir() as d:
pscript = """
@@ -106,6 +129,7 @@ else: # should never happen because of skipping
" ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
],
"outputs": [os.sep.join((str(d), "f"))],
+ # output should work even if file doesn't exist
},
{
"name": "commandB",
@@ -133,6 +157,9 @@ else: # should never happen because of skipping
@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6])
def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
+ """Confirms that no errors are thrown when a parallel group is executed
+ with various values of max_parallel_processes.
+ """
with make_tempdir() as d:
project = {
@@ -185,6 +212,12 @@ def test_project_run_multiprocessing_max_processes_good_case(max_parallel_proces
@pytest.mark.parametrize("failing_command", ["python", "someNotExistentOSCommand"])
def test_project_run_multiprocessing_failure(failing_command: str):
+ """Confirms that when one command in a parallel group fails, the other commands
+ are terminated. The script takes two arguments: the number of seconds for which
+ to sleep, and the return code. One command in the group fails immediately
+ with a non-zero return code, while the other two commands complete after several
+ seconds with zero return codes. Measuring the execution time of the whole group shows
+ whether or not the sleeping processes were successfully terminated."""
with make_tempdir() as d:
pscript = """