diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index 7a6ef41c7..129bb4c0e 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -266,11 +266,7 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
         if isinstance(workflow_item, str):
             verify_workflow_step(workflow_name, commands, workflow_item)
         else:
-            assert isinstance(workflow_item, dict)
-            assert len(workflow_item) == 1
-            steps_list = workflow_item["parallel"]
-            assert isinstance(steps_list[0], str)
-            steps = cast(List[str], steps_list)
+            steps = cast(List[str], workflow_item["parallel"])
             if len(steps) < 2:
                 msg.fail(
                     f"Invalid parallel group within '{workflow_name}'.",
diff --git a/spacy/cli/project/parallel.py b/spacy/cli/project/parallel.py
index 1d3ad675d..dab563485 100644
--- a/spacy/cli/project/parallel.py
+++ b/spacy/cli/project/parallel.py
@@ -1,22 +1,3 @@
-from typing import Any, List, Optional, Dict, Union, cast
-import os
-import sys
-from pathlib import Path
-from time import time
-from multiprocessing import Manager, Queue, Process, get_context
-from multiprocessing.context import SpawnProcess
-from shutil import rmtree
-from signal import SIGTERM
-from subprocess import STDOUT, Popen, TimeoutExpired
-from dataclasses import dataclass, field
-from wasabi import msg
-from wasabi.util import color, supports_ansi
-
-from .._util import check_rerun, check_deps, update_lockfile, load_project_config
-from ...util import SimpleFrozenDict, working_dir, split_command, join_command
-from ...util import check_bool_env_var, ENV_VARS
-from ...errors import Errors
-
 """
 Permits the execution of a parallel command group.
 
@@ -30,6 +11,25 @@ of the states is documented alongside the _ParallelCommandInfo.STATES code. Note
 between the states 'failed' and 'terminated' is not meaningful on Windows, so that
 both are displayed as 'failed/terminated' on Windows systems.
 """
+from typing import Any, List, Optional, Dict, Union, cast, Iterator
+import os
+import sys
+import queue
+from pathlib import Path
+from time import time
+from multiprocessing import Manager, Queue, get_context
+from multiprocessing.context import SpawnProcess
+from shutil import rmtree
+from signal import SIGTERM
+from subprocess import STDOUT, Popen, TimeoutExpired
+from dataclasses import dataclass, field
+from wasabi import msg
+from wasabi.util import color, supports_ansi
+
+from .._util import check_rerun, check_deps, update_lockfile, load_project_config
+from ...util import SimpleFrozenDict, working_dir, split_command, join_command
+from ...util import check_bool_env_var, ENV_VARS
+from ...errors import Errors
 
 # Use spawn to create worker processes on all OSs for consistency
 mp_context = get_context("spawn")
@@ -145,7 +145,6 @@ def project_run_parallel_group(
     config = load_project_config(project_dir, overrides=overrides)
     commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
     check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
-    parallel_group_status_queue = Manager().Queue()
     max_parallel_processes = config.get("max_parallel_processes")
     cmd_infos = [
         _ParallelCommandInfo(cmd_name, commands[cmd_name], cmd_index)
@@ -153,6 +152,27 @@
     ]
 
     with working_dir(project_dir) as current_dir:
+        rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True)
+        os.mkdir(PARALLEL_LOGGING_DIR_NAME)
+
+        divider_parallel_descriptor = parallel_descriptor = (
+            "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]"
+        )
+        if len(divider_parallel_descriptor) > MAX_WIDTH_DIVIDER:
+            divider_parallel_descriptor = (
+                divider_parallel_descriptor[: (MAX_WIDTH_DIVIDER - 3)] + "..."
+            )
+        msg.divider(divider_parallel_descriptor)
+        if not DISPLAY_STATUS_TABLE and len(parallel_descriptor) > MAX_WIDTH_DIVIDER:
+            # reprint the descriptor if it was too long and had to be cut short
+            print(parallel_descriptor)
+        msg.info(
+            "Temporary logs are being written to "
+            + os.sep.join((str(current_dir), PARALLEL_LOGGING_DIR_NAME))
+        )
+
+        parallel_group_status_queue = Manager().Queue()
+
         for cmd_info in cmd_infos:
             check_deps(cmd_info.cmd, cmd_info.cmd_name, project_dir, dry)
             if (
@@ -162,8 +182,6 @@
                 and not force
             ):
                 cmd_info.change_state("not rerun")
-        rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True)
-        os.mkdir(PARALLEL_LOGGING_DIR_NAME)
         worker_processes: List[SpawnProcess] = []
         proc_to_cmd_infos: Dict[SpawnProcess, _ParallelCommandInfo] = {}
         num_concurr_worker_processes = 0
@@ -181,110 +199,25 @@
             )
             worker_processes.append(worker_process)
             proc_to_cmd_infos[worker_process] = cmd_info
+        num_concurr_worker_processes = len(worker_processes)
         if (
             max_parallel_processes is not None
             and max_parallel_processes < num_concurr_worker_processes
         ):
             num_concurr_worker_processes = max_parallel_processes
+        worker_process_iterator = iter(worker_processes)
         for _ in range(num_concurr_worker_processes):
             _start_worker_process(next(worker_process_iterator), proc_to_cmd_infos)
+        _process_worker_status_messages(
+            cmd_infos,
+            proc_to_cmd_infos,
+            parallel_group_status_queue,
+            worker_process_iterator,
+            current_dir,
+            dry,
+        )
-        divider_parallel_descriptor = parallel_descriptor = (
-            "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]"
-        )
-        if len(divider_parallel_descriptor) > MAX_WIDTH_DIVIDER:
-            divider_parallel_descriptor = (
-                divider_parallel_descriptor[: (MAX_WIDTH_DIVIDER - 3)] + "..."
-            )
-        msg.divider(divider_parallel_descriptor)
-        if not DISPLAY_STATUS_TABLE and len(parallel_descriptor) > MAX_WIDTH_DIVIDER:
-            # reprint the descriptor if it was too long and had to be cut short
-            print(parallel_descriptor)
-
-        msg.info(
-            "Temporary logs are being written to "
-            + os.sep.join((str(current_dir), PARALLEL_LOGGING_DIR_NAME))
-        )
-        status_table_not_yet_displayed = True
-        while any(
-            cmd_info.state.name in ("pending", "starting", "running")
-            for cmd_info in cmd_infos
-        ):
-            try:
-                mess: Dict[str, Union[str, int]] = parallel_group_status_queue.get(
-                    timeout=PARALLEL_GROUP_STATUS_INTERVAL * 20
-                )
-            except Exception:
-                # No more messages are being received: the whole group has hung
-                for other_cmd_info in (
-                    c for c in cmd_infos if c.state.name in ("starting", "running")
-                ):
-                    other_cmd_info.change_state("hung")
-                for other_cmd_info in (
-                    c for c in cmd_infos if c.state.name == "pending"
-                ):
-                    other_cmd_info.change_state("cancelled")
-                break
-            cmd_info = cmd_infos[cast(int, mess["cmd_index"])]
-            if mess["type"] in ("started", "keepalive"):
-                cmd_info.last_keepalive_time = int(time())
-                for other_cmd_info in (
-                    c for c in cmd_infos if c.state.name in ("starting", "running")
-                ):
-                    if (
-                        other_cmd_info.last_keepalive_time is not None
-                        and time() - other_cmd_info.last_keepalive_time
-                        > PARALLEL_GROUP_STATUS_INTERVAL * 20
-                    ):
-                        # a specific command has hung
-                        other_cmd_info.change_state("hung")
-            if mess["type"] == "started":
-                cmd_info.change_state("running")
-                cmd_info.running_os_cmd_index = cast(int, mess["os_cmd_index"])
-                cmd_info.pid = cast(int, mess["pid"])
-            if mess["type"] == "completed":
-                cmd_info.rc = cast(int, mess["rc"])
-                if cmd_info.rc == 0:
-                    cmd_info.change_state("succeeded")
-                    if not dry:
-                        update_lockfile(current_dir, cmd_info.cmd)
-                    next_worker_process = next(worker_process_iterator, None)
-                    if next_worker_process is not None:
-                        _start_worker_process(next_worker_process, proc_to_cmd_infos)
-                elif cmd_info.rc > 0:
-                    cmd_info.change_state("failed")
-                else:
-                    cmd_info.change_state("terminated")
-                cmd_info.console_output = cast(str, mess["console_output"])
-            if any(
-                c for c in cmd_infos if c.state.name in ("failed", "terminated", "hung")
-            ):
-                # a command in the group hasn't succeeded, so terminate/cancel the rest
-                for other_cmd_info in (
-                    c for c in cmd_infos if c.state.name == "running"
-                ):
-                    try:
-                        os.kill(cast(int, other_cmd_info.pid), SIGTERM)
-                    except:
-                        # the subprocess the main process is trying to kill could already
-                        # have completed, and the message from the worker process notifying
-                        # the main process about this could still be in the queue
-                        pass
-                for other_cmd_info in (
-                    c for c in cmd_infos if c.state.name == "pending"
-                ):
-                    other_cmd_info.change_state("cancelled")
-            if mess["type"] != "keepalive" and DISPLAY_STATUS_TABLE:
-                if status_table_not_yet_displayed:
-                    status_table_not_yet_displayed = False
-                else:
-                    # overwrite the existing status table
-                    print("\033[2K\033[F" * (4 + len(cmd_infos)))
-                data = [[c.cmd_name, c.state_repr] for c in cmd_infos]
-                header = ["Command", "Status"]
-                msg.table(data, header=header)
-
         for cmd_info in (c for c in cmd_infos if c.state.name != "cancelled"):
             msg.divider(cmd_info.cmd_name)
             if cmd_info.state.name == "not rerun":
@@ -305,6 +238,98 @@
             sys.exit(-1)
 
+
+def _process_worker_status_messages(
+    cmd_infos: List[_ParallelCommandInfo],
+    proc_to_cmd_infos: Dict[SpawnProcess, _ParallelCommandInfo],
+    parallel_group_status_queue: queue.Queue,
+    worker_process_iterator: Iterator[SpawnProcess],
+    current_dir: Path,
+    dry: bool,
+) -> None:
+    """Listens on the status queue and processes messages received from the worker processes.
+
+    cmd_infos: a list of info objects about the commands in the parallel group
+    proc_to_cmd_infos: a dictionary from Process objects to command info objects
+    parallel_group_status_queue: the status queue
+    worker_process_iterator: an iterator over the processes, some or all of which
+        will already have been iterated over and started
+    current_dir: the current directory
+    dry (bool): Perform a dry run and don't execute commands.
+    """
+    status_table_not_yet_displayed = True
+    while any(
+        cmd_info.state.name in ("pending", "starting", "running")
+        for cmd_info in cmd_infos
+    ):
+        try:
+            mess: Dict[str, Union[str, int]] = parallel_group_status_queue.get(
+                timeout=PARALLEL_GROUP_STATUS_INTERVAL * 20
+            )
+        except Exception:
+            # No more messages are being received: the whole group has hung
+            for other_cmd_info in (
+                c for c in cmd_infos if c.state.name in ("starting", "running")
+            ):
+                other_cmd_info.change_state("hung")
+            for other_cmd_info in (c for c in cmd_infos if c.state.name == "pending"):
+                other_cmd_info.change_state("cancelled")
+            break
+        cmd_info = cmd_infos[cast(int, mess["cmd_index"])]
+        if mess["mess_type"] in ("started", "keepalive"):
+            cmd_info.last_keepalive_time = int(time())
+            for other_cmd_info in (
+                c for c in cmd_infos if c.state.name in ("starting", "running")
+            ):
+                if (
+                    other_cmd_info.last_keepalive_time is not None
+                    and time() - other_cmd_info.last_keepalive_time
+                    > PARALLEL_GROUP_STATUS_INTERVAL * 20
+                ):
+                    # a specific command has hung
+                    other_cmd_info.change_state("hung")
+        if mess["mess_type"] == "started":
+            cmd_info.change_state("running")
+            cmd_info.running_os_cmd_index = cast(int, mess["os_cmd_index"])
+            cmd_info.pid = cast(int, mess["pid"])
+        if mess["mess_type"] == "completed":
+            cmd_info.rc = cast(int, mess["rc"])
+            if cmd_info.rc == 0:
+                cmd_info.change_state("succeeded")
+                if not dry:
+                    update_lockfile(current_dir, cmd_info.cmd)
+                next_worker_process = next(worker_process_iterator, None)
+                if next_worker_process is not None:
+                    _start_worker_process(next_worker_process, proc_to_cmd_infos)
+            elif cmd_info.rc > 0:
+                cmd_info.change_state("failed")
+            else:
+                cmd_info.change_state("terminated")
+            cmd_info.console_output = cast(str, mess["console_output"])
+        if any(
+            c for c in cmd_infos if c.state.name in ("failed", "terminated", "hung")
+        ):
+            # a command in the group hasn't succeeded, so terminate/cancel the rest
+            for other_cmd_info in (c for c in cmd_infos if c.state.name == "running"):
+                try:
+                    os.kill(cast(int, other_cmd_info.pid), SIGTERM)
+                except:
+                    # the subprocess the main process is trying to kill could already
+                    # have completed, and the message from the worker process notifying
+                    # the main process about this could still be in the queue
+                    pass
+            for other_cmd_info in (c for c in cmd_infos if c.state.name == "pending"):
+                other_cmd_info.change_state("cancelled")
+        if mess["mess_type"] != "keepalive" and DISPLAY_STATUS_TABLE:
+            if status_table_not_yet_displayed:
+                status_table_not_yet_displayed = False
+            else:
+                # overwrite the existing status table
+                print("\033[2K\033[F" * (4 + len(cmd_infos)))
+            data = [[c.cmd_name, c.state_repr] for c in cmd_infos]
+            header = ["Command", "Status"]
+            msg.table(data, header=header)
+
+
 def _project_run_parallel_cmd(
     cmd_info: _ParallelCommandInfo,
     *,
@@ -314,9 +339,9 @@
 ) -> None:
     """Run a single spaCy projects command as a worker process.
 
-    Communicates with the main process via queue messages that consist of
-    dictionaries whose type is determined by the entry 'type'. Possible
-    values of 'type' are 'started', 'completed' and 'keepalive'. Each dictionary
+    Communicates with the main process via queue messages that are structured as
+    dictionaries and whose type is determined by the entry 'mess_type'. Possible
+    values of 'mess_type' are 'started', 'completed' and 'keepalive'. Each dictionary
     type contains different additional fields."""
     # we can use the command name as a unique log filename because a parallel
    # group is not allowed to contain the same command more than once
@@ -331,6 +356,8 @@
         if len(command) and command[0] in ("python", "python3"):
             # -u: prevent buffering within Python
             command = [sys.executable, "-u", *command[1:]]
+        elif len(command) and command[0].startswith("python3"):  # e.g. python3.10
+            command = [command[0], "-u", *command[1:]]
         elif len(command) and command[0] in ("pip", "pip3"):
             command = [sys.executable, "-m", "pip", *command[1:]]
         logfile.write(
@@ -364,7 +391,7 @@
                 break
             parallel_group_status_queue.put(
                 {
-                    "type": "started",
+                    "mess_type": "started",
                     "cmd_index": cmd_info.cmd_index,
                     "os_cmd_index": os_cmd_index,
                     "pid": sp.pid,
@@ -378,7 +405,7 @@
                 if sp.returncode == None:
                     parallel_group_status_queue.put(
                         {
-                            "type": "keepalive",
+                            "mess_type": "keepalive",
                             "cmd_index": cmd_info.cmd_index,
                         }
                     )
@@ -403,7 +430,7 @@
         else:  # dry run
             parallel_group_status_queue.put(
                 {
-                    "type": "started",
+                    "mess_type": "started",
                     "cmd_index": cmd_info.cmd_index,
                     "os_cmd_index": os_cmd_index,
                     "pid": -1,
@@ -419,7 +446,7 @@
         rc = sp.returncode
         parallel_group_status_queue.put(
             {
-                "type": "completed",
+                "mess_type": "completed",
                 "cmd_index": cmd_info.cmd_index,
                 "rc": rc,
                 "console_output": logfile.read(),
diff --git a/spacy/tests/test_parallel.py b/spacy/tests/test_parallel.py
index 5fadebb88..3d15cb204 100644
--- a/spacy/tests/test_parallel.py
+++ b/spacy/tests/test_parallel.py
@@ -8,6 +8,7 @@ from .util import make_tempdir
 
 
 def test_project_config_multiprocessing_good_case():
+    """Confirms that a correct parallel command group is loaded from the project file."""
     project = {
         "commands": [
             {"name": "command1", "script": ["echo", "command1"]},
@@ -31,6 +32,14 @@
     ],
 )
 def test_project_config_multiprocessing_bad_case_workflows(workflows):
+    """Confirms that parallel command groups that do not obey various rules
+    result in errors:
+
+    - a parallel group with only one command
+    - a parallel group with a non-existent command
+    - a parallel group containing the same command more than once
+    - a parallel group whose dictionary name is not 'parallel'
+    """
     project = {
         "commands": [
             {"name": "command1", "script": ["echo", "command1"]},
@@ -47,6 +56,13 @@
 
 @pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
 def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
+    """Confirms that invalid max_parallel_processes values are not accepted:
+
+    - a negative value
+    - zero
+    - 1 (because the commands in the group would then be executed serially anyway)
+    """
+
     with make_tempdir() as d:
         project = {
             "max_parallel_processes": max_parallel_processes,
@@ -73,6 +89,13 @@ def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_proc
 
 
 def test_project_run_multiprocessing_good_case():
+    """Executes a parallel command group. The script pscript.py writes to a file, then polls
+    until a second file exists. It is executed twice by commands in a parallel group that write
+    and read opposite files, thus proving that the commands are being executed simultaneously.
+    The script then checks whether a file f exists, writes to f if it does not, and writes to a
+    second file if it does. Although the workflow is run twice, the fact that both commands
+    should be skipped on the second run means that these second files should never be written.
+    """
     with make_tempdir() as d:
 
         pscript = """
@@ -106,6 +129,7 @@ else: # should never happen because of skipping
                         " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
                     ],
                     "outputs": [os.sep.join((str(d), "f"))],
+                    # output should work even if file doesn't exist
                 },
                 {
                     "name": "commandB",
@@ -133,6 +157,9 @@
 
 @pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6])
 def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
+    """Confirms that no errors are thrown when a parallel group is executed
+    with various values of max_parallel_processes.
+    """
     with make_tempdir() as d:
 
         project = {
@@ -185,6 +212,12 @@ def test_project_run_multiprocessing_max_processes_good_case(max_parallel_proces
 
 @pytest.mark.parametrize("failing_command", ["python", "someNotExistentOSCommand"])
 def test_project_run_multiprocessing_failure(failing_command: str):
+    """Confirms that when one command in a parallel group fails, the other commands
+    are terminated. The script takes two arguments: the number of seconds for which
+    to sleep, and the return code. One command in the group exits immediately with a
+    non-zero return code; the other two would exit after several seconds with zero
+    return codes. Measuring the execution length for the whole group shows whether
+    or not the sleeping processes were successfully terminated."""
     with make_tempdir() as d:
 
         pscript = """
@@ -232,7 +265,7 @@ sys.exit(int(rc))
             # 15 is the highest rc rather than 1.
             assert rc_e.value.code == 15
         else:
-            assert rc_e.value.code == 1
+            assert rc_e.value.code == 1
     assert (
         time() - start < 5
     ), "Test took too long, subprocess seems not to have been terminated"
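
For reference, the test fixtures and the validation logic above imply the following shape for a project config containing a parallel workflow group. This is an illustrative sketch only: the command and workflow names are invented, and it assumes just the rules the diff enforces (a "parallel" group must list at least two distinct, existing commands, and max_parallel_processes must be at least 2).

    # Illustrative sketch; names are invented, structure follows the dicts used in
    # spacy/tests/test_parallel.py and the checks in validate_project_commands.
    project = {
        "commands": [
            {"name": "preprocess", "script": ["echo", "preprocess"]},
            {"name": "train", "script": ["echo", "train"]},
            {"name": "evaluate", "script": ["echo", "evaluate"]},
        ],
        "workflows": {
            # a workflow item is either a command name (run serially) or a
            # {"parallel": [...]} group of at least two distinct, existing commands
            "all": ["preprocess", {"parallel": ["train", "evaluate"]}],
        },
        # optional cap on concurrent worker processes; values below 2 are rejected
        "max_parallel_processes": 2,
    }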
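
Likewise, the queue.put() calls in _project_run_parallel_cmd define the status messages that _process_worker_status_messages consumes. The sketch below only restates those message shapes; the concrete values are examples, and handle_message is a hypothetical stand-in for the dispatch performed in the main process.

    from typing import Dict, Union

    # Field names are taken from the queue.put() calls in the diff; values are examples.
    started_mess: Dict[str, Union[str, int]] = {
        "mess_type": "started",
        "cmd_index": 0,     # index of the spaCy command within the parallel group
        "os_cmd_index": 0,  # position of the OS command within that command's script (inferred)
        "pid": 12345,       # pid of the spawned subprocess (-1 on a dry run)
    }
    keepalive_mess = {"mess_type": "keepalive", "cmd_index": 0}
    completed_mess = {
        "mess_type": "completed",
        "cmd_index": 0,
        "rc": 0,               # subprocess return code
        "console_output": "",  # contents of the command's temporary log file
    }

    def handle_message(mess: Dict[str, Union[str, int]]) -> None:
        # Hypothetical dispatcher mirroring _process_worker_status_messages
        if mess["mess_type"] in ("started", "keepalive"):
            pass  # refresh the command's keepalive timestamp
        if mess["mess_type"] == "started":
            pass  # mark the command as running and record its pid
        elif mess["mess_type"] == "completed":
            pass  # record rc and console output; on success, start the next queued worker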