From 4c2fc56a5b796264455ebe125677b99da9fd64e3 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Tue, 19 Jul 2022 14:08:46 +0200
Subject: [PATCH] Refactoring into separate module

---
 spacy/cli/_util.py            | 122 +++++++++++++++++++++++
 spacy/cli/project/parallel.py |  74 ++++++++------
 spacy/cli/project/pull.py     |   6 +-
 spacy/cli/project/run.py      | 171 ++-----------------------------
 spacy/tests/test_cli.py       | 165 ------------------------------
 spacy/tests/test_parallel.py  | 183 ++++++++++++++++++++++++++++++++++
 6 files changed, 359 insertions(+), 362 deletions(-)
 create mode 100644 spacy/tests/test_parallel.py

diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index 3a64e3211..90b03d5b1 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -18,8 +18,12 @@ import os
 
 from ..compat import Literal
 from ..schemas import ProjectConfigSchema, validate
+from ..git_info import GIT_VERSION
 from ..util import import_file, run_command, make_tempdir, registry, logger
 from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
+from ..util import is_minor_version_match
+
+
 from .. import about
 
 if TYPE_CHECKING:
@@ -598,3 +602,121 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
         local_msg.info("Using CPU")
         if has_cupy and gpu_is_available():
             local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")
+
+
+def check_rerun(
+    project_dir: Path,
+    command: Dict[str, Any],
+    *,
+    check_spacy_version: bool = True,
+    check_spacy_commit: bool = False,
+) -> bool:
+    """Check if a command should be rerun because its settings or inputs/outputs
+    changed.
+
+    project_dir (Path): The current project directory.
+    command (Dict[str, Any]): The command, as defined in the project.yml.
+    strict_version (bool):
+    RETURNS (bool): Whether to re-run the command.
+    """
+    # Always rerun if no-skip is set
+    if command.get("no_skip", False):
+        return True
+    lock_path = project_dir / PROJECT_LOCK
+    if not lock_path.exists():  # We don't have a lockfile, run command
+        return True
+    data = srsly.read_yaml(lock_path)
+    if command["name"] not in data:  # We don't have info about this command
+        return True
+    entry = data[command["name"]]
+    # Always run commands with no outputs (otherwise they'd always be skipped)
+    if not entry.get("outs", []):
+        return True
+    # Always rerun if spaCy version or commit hash changed
+    spacy_v = entry.get("spacy_version")
+    commit = entry.get("spacy_git_version")
+    if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
+        info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
+        msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
+        return True
+    if check_spacy_commit and commit != GIT_VERSION:
+        info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
+        msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
+        return True
+    # If the entry in the lockfile matches the lockfile entry that would be
+    # generated from the current command, we don't rerun because it means that
+    # all inputs/outputs, hashes and scripts are the same and nothing changed
+    lock_entry = _get_lock_entry(project_dir, command)
+    exclude = ["spacy_version", "spacy_git_version"]
+    return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
+
+
+def update_lockfile(
+    project_dir: Path,
+    command: Dict[str, Any],
+) -> None:
+    """Update the lockfile after running a command. Will create a lockfile if
+    it doesn't yet exist and will add an entry for the current command, its
+    script and dependencies/outputs.
+
+    project_dir (Path): The current project directory.
+    command (Dict[str, Any]): The command, as defined in the project.yml.
+    """
+    lock_path = project_dir / PROJECT_LOCK
+    if not lock_path.exists():
+        srsly.write_yaml(lock_path, {})
+        data = {}
+    else:
+        data = srsly.read_yaml(lock_path)
+    data[command["name"]] = _get_lock_entry(project_dir, command)
+    srsly.write_yaml(lock_path, data)
+
+
+def _get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
+    """Get a lockfile entry for a given command. An entry includes the command,
+    the script (command steps) and a list of dependencies and outputs with
+    their paths and file hashes, if available. The format is based on the
+    dvc.lock files, to keep things consistent.
+
+    project_dir (Path): The current project directory.
+    command (Dict[str, Any]): The command, as defined in the project.yml.
+    RETURNS (Dict[str, Any]): The lockfile entry.
+    """
+    deps = _get_fileinfo(project_dir, command.get("deps", []))
+    outs = _get_fileinfo(project_dir, command.get("outputs", []))
+    outs_nc = _get_fileinfo(project_dir, command.get("outputs_no_cache", []))
+    return {
+        "cmd": f"{COMMAND} run {command['name']}",
+        "script": command["script"],
+        "deps": deps,
+        "outs": [*outs, *outs_nc],
+        "spacy_version": about.__version__,
+        "spacy_git_version": GIT_VERSION,
+    }
+
+
+def _get_fileinfo(
+    project_dir: Path, paths: List[str]
+) -> List[Dict[str, Optional[str]]]:
+    """Generate the file information for a list of paths (dependencies, outputs).
+    Includes the file path and the file's checksum.
+
+    project_dir (Path): The current project directory.
+    paths (List[str]): The file paths.
+    RETURNS (List[Dict[str, str]]): The lockfile entry for a file.
+    """
+    data = []
+    for path in paths:
+        file_path = project_dir / path
+        md5 = get_checksum(file_path) if file_path.exists() else None
+        data.append({"path": path, "md5": md5})
+    return data
+
+
+def check_deps(cmd: Dict, cmd_name: str, project_dir: Path, dry: bool):
+    for dep in cmd.get("deps", []):
+        if not (project_dir / dep).exists():
+            err = f"Missing dependency specified by command '{cmd_name}': {dep}"
+            err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
+ err_kwargs = {"exits": 1} if not dry else {} + msg.fail(err, err_help, **err_kwargs) diff --git a/spacy/cli/project/parallel.py b/spacy/cli/project/parallel.py index 98338523b..49de62168 100644 --- a/spacy/cli/project/parallel.py +++ b/spacy/cli/project/parallel.py @@ -1,4 +1,5 @@ -from typing import Any, List, Literal, Optional, Dict +from typing import Any, List, Literal, Optional, Dict, Union, cast +import sys from pathlib import Path from time import time import multiprocessing @@ -7,14 +8,15 @@ from os import kill, environ, linesep, mkdir, sep from shutil import rmtree from signal import SIGTERM from subprocess import STDOUT, Popen, TimeoutExpired +from wasabi import msg from wasabi.util import color, supports_ansi -from ...util import SimpleFrozenDict, load_project_config, working_dir -from ...util import check_bool_env_var +from .._util import check_rerun, check_deps, update_lockfile, load_project_config +from ...util import SimpleFrozenDict, working_dir +from ...util import check_bool_env_var, ENV_VARS, split_command, join_command from ...errors import Errors - # How often the worker processes managing the commands in a parallel group # send keepalive messages to the main processes PARALLEL_GROUP_STATUS_INTERVAL = 1 @@ -34,8 +36,10 @@ DISPLAY_STATUS_COLORS = { "hung": "red", "cancelled": "red", } -class ParallelCommandInfo: - def __init__(self, cmd_name: str, cmd: str, cmd_ind: int): + + +class _ParallelCommandInfo: + def __init__(self, cmd_name: str, cmd: Dict, cmd_ind: int): self.cmd_name = cmd_name self.cmd = cmd self.cmd_ind = cmd_ind @@ -59,21 +63,26 @@ class ParallelCommandInfo: @property def disp_status(self) -> str: - status = self.status - status_color = DISPLAY_STATUS_COLORS[status] - if status == "running": - status = f"{status} ({self.os_cmd_ind + 1}/{self.len_os_cmds})" - return color(status, status_color) + status_str = str(self.status) + status_color = DISPLAY_STATUS_COLORS[status_str] + if status_str == "running" and self.os_cmd_ind is not None: + status_str = f"{status_str} ({self.os_cmd_ind + 1}/{self.len_os_cmds})" + return color(status_str, status_color) -def start_process(process:Process, proc_to_cmd_infos: Dict[Process, ParallelCommandInfo]) -> None: + +def _start_process( + process: Process, proc_to_cmd_infos: Dict[Process, _ParallelCommandInfo] +) -> None: cmd_info = proc_to_cmd_infos[process] if cmd_info.status == "pending": cmd_info.status = "starting" + cmd_info.last_status_time = int(time()) process.start() + def project_run_parallel_group( project_dir: Path, - cmd_infos: List[ParallelCommandInfo], + cmd_names: List[str], *, overrides: Dict[str, Any] = SimpleFrozenDict(), force: bool = False, @@ -90,11 +99,18 @@ def project_run_parallel_group( dry (bool): Perform a dry run and don't execute commands. 
""" config = load_project_config(project_dir, overrides=overrides) + commands = {cmd["name"]: cmd for cmd in config.get("commands", [])} + parallel_group_status_queue = MultiprocessingManager().Queue() max_parallel_processes = config.get("max_parallel_processes") check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION) multiprocessing.set_start_method("spawn", force=True) DISPLAY_STATUS = sys.stdout.isatty() and supports_ansi() + cmd_infos = [ + _ParallelCommandInfo(cmd_name, commands[cmd_name], cmd_ind) + for cmd_ind, cmd_name in enumerate(cmd_names) + ] + with working_dir(project_dir) as current_dir: for cmd_info in cmd_infos: check_deps(cmd_info.cmd, cmd_info.cmd_name, project_dir, dry) @@ -107,14 +123,14 @@ def project_run_parallel_group( cmd_info.status = "not rerun" rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True) mkdir(PARALLEL_LOGGING_DIR_NAME) - processes : List[Process] = [] - proc_to_cmd_infos: Dict[Process, ParallelCommandInfo] = {} + processes: List[Process] = [] + proc_to_cmd_infos: Dict[Process, _ParallelCommandInfo] = {} num_processes = 0 for cmd_info in cmd_infos: if cmd_info.status == "not rerun": continue process = Process( - target=project_run_parallel_cmd, + target=_project_run_parallel_cmd, args=(cmd_info,), kwargs={ "dry": dry, @@ -132,7 +148,7 @@ def project_run_parallel_group( num_processes = max_parallel_processes process_iterator = iter(processes) for _ in range(num_processes): - start_process(next(process_iterator), proc_to_cmd_infos) + _start_process(next(process_iterator), proc_to_cmd_infos) divider_parallel_descriptor = parallel_descriptor = ( "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]" ) @@ -157,10 +173,12 @@ def project_run_parallel_group( other_cmd_info.status = "hung" for other_cmd_info in (c for c in cmd_infos if c.status == "pending"): other_cmd_info.status = "cancelled" - cmd_info = cmd_infos[mess["cmd_ind"]] + cmd_info = cmd_infos[cast(int, mess["cmd_ind"])] if mess["status"] in ("started", "alive"): cmd_info.last_status_time = int(time()) - for other_cmd_info in (c for c in cmd_infos if c.status == "running"): + for other_cmd_info in ( + c for c in cmd_infos if c.status in ("starting", "running") + ): if ( other_cmd_info.last_status_time is not None and time() - other_cmd_info.last_status_time @@ -169,26 +187,26 @@ def project_run_parallel_group( other_cmd_info.status = "hung" if mess["status"] == "started": cmd_info.status = "running" - cmd_info.os_cmd_ind = mess["os_cmd_ind"] - cmd_info.pid = mess["pid"] + cmd_info.os_cmd_ind = cast(int, mess["os_cmd_ind"]) + cmd_info.pid = cast(int, mess["pid"]) if mess["status"] == "completed": - cmd_info.rc = mess["rc"] + cmd_info.rc = cast(int, mess["rc"]) if cmd_info.rc == 0: cmd_info.status = "succeeded" if not dry: update_lockfile(current_dir, cmd_info.cmd) working_process = next(process_iterator, None) if working_process is not None: - start_process(working_process, proc_to_cmd_infos) + _start_process(working_process, proc_to_cmd_infos) elif cmd_info.rc > 0: cmd_info.status = "failed" else: cmd_info.status = "killed" - cmd_info.output = mess["output"] + cmd_info.output = cast(str, mess["output"]) if any(c for c in cmd_infos if c.status in ("failed", "killed", "hung")): for other_cmd_info in (c for c in cmd_infos if c.status == "running"): try: - kill(other_cmd_info.pid, SIGTERM) + kill(cast(int, other_cmd_info.pid), SIGTERM) except: pass for other_cmd_info in (c for c in cmd_infos if c.status == "pending"): @@ -219,8 +237,8 @@ def 
project_run_parallel_group( sys.exit(-1) -def project_run_parallel_cmd( - cmd_info: ParallelCommandInfo, +def _project_run_parallel_cmd( + cmd_info: _ParallelCommandInfo, *, dry: bool, current_dir: str, @@ -318,4 +336,4 @@ def project_run_parallel_cmd( "rc": rc, "output": logfile.read(), } - ) \ No newline at end of file + ) diff --git a/spacy/cli/project/pull.py b/spacy/cli/project/pull.py index 9496cdb1a..d225eeaca 100644 --- a/spacy/cli/project/pull.py +++ b/spacy/cli/project/pull.py @@ -4,8 +4,7 @@ from wasabi import msg from .remote_storage import RemoteStorage from .remote_storage import get_command_hash from .._util import project_cli, Arg, logger -from .._util import load_project_config -from .run import update_lockfile +from .._util import load_project_config, update_lockfile @project_cli.command("pull") @@ -38,7 +37,6 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False): # We use a while loop here because we don't know how the commands # will be ordered. A command might need dependencies from one that's later # in the list. - mult_group_mutex = multiprocessing.Lock() while commands: for i, cmd in enumerate(list(commands)): logger.debug(f"CMD: {cmd['name']}.") @@ -54,7 +52,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False): out_locs = [project_dir / out for out in cmd.get("outputs", [])] if all(loc.exists() for loc in out_locs): - update_lockfile(project_dir, cmd, mult_group_mutex=mult_group_mutex) + update_lockfile(project_dir, cmd) # We remove the command from the list here, and break, so that # we iterate over the loop again. commands.pop(i) diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py index 2fa071d36..52a21d97e 100644 --- a/spacy/cli/project/run.py +++ b/spacy/cli/project/run.py @@ -3,16 +3,15 @@ import sys from pathlib import Path from wasabi import msg from wasabi.util import locale_escape -import srsly import typer -from ... import about -from ...git_info import GIT_VERSION + +from .parallel import project_run_parallel_group +from .._util import project_cli, Arg, Opt, COMMAND, parse_config_overrides, check_deps +from .._util import PROJECT_FILE, load_project_config, check_rerun, update_lockfile from ...util import working_dir, run_command, split_command, is_cwd, join_command -from ...util import SimpleFrozenList, is_minor_version_match, ENV_VARS +from ...util import SimpleFrozenList, ENV_VARS from ...util import check_bool_env_var, SimpleFrozenDict -from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash -from .._util import get_checksum, project_cli, Arg, Opt, COMMAND, parse_config_overrides @project_cli.command( @@ -42,15 +41,6 @@ def project_run_cli( project_run(project_dir, subcommand, overrides=overrides, force=force, dry=dry) -def check_deps(cmd: Dict, cmd_name: str, project_dir: Path, dry: bool): - for dep in cmd.get("deps", []): - if not (project_dir / dep).exists(): - err = f"Missing dependency specified by command '{cmd_name}': {dep}" - err_help = "Maybe you forgot to run the 'project assets' command or a previous step?" 
- err_kwargs = {"exits": 1} if not dry else {} - msg.fail(err, err_help, **err_kwargs) - - def project_run( project_dir: Path, subcommand: str, @@ -99,10 +89,7 @@ def project_run( assert isinstance(cmds[0], str) project_run_parallel_group( project_dir, - [ - ParallelCommandInfo(cmd, commands[cmd], cmd_ind) - for cmd_ind, cmd in enumerate(cmds) - ], + cmds, overrides=overrides, force=force, dry=dry, @@ -130,9 +117,6 @@ def project_run( update_lockfile(current_dir, cmd) - - - def run_commands( commands: Iterable[str] = SimpleFrozenList(), silent: bool = False, @@ -250,42 +234,6 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None: msg.table(table_entries) -def run_commands( - commands: Iterable[str] = SimpleFrozenList(), - silent: bool = False, - dry: bool = False, - capture: bool = False, -) -> None: - """Run a sequence of commands in a subprocess, in order. - - commands (List[str]): The string commands. - silent (bool): Don't print the commands. - dry (bool): Perform a dry run and don't execute anything. - capture (bool): Whether to capture the output and errors of individual commands. - If False, the stdout and stderr will not be redirected, and if there's an error, - sys.exit will be called with the return code. You should use capture=False - when you want to turn over execution to the command, and capture=True - when you want to run the command more like a function. - """ - for c in commands: - command = split_command(c) - # Not sure if this is needed or a good idea. Motivation: users may often - # use commands in their config that reference "python" and we want to - # make sure that it's always executing the same Python that spaCy is - # executed with and the pip in the same env, not some other Python/pip. - # Also ensures cross-compatibility if user 1 writes "python3" (because - # that's how it's set up on their system), and user 2 without the - # shortcut tries to re-run the command. - if len(command) and command[0] in ("python", "python3"): - command[0] = sys.executable - elif len(command) and command[0] in ("pip", "pip3"): - command = [sys.executable, "-m", "pip", *command[1:]] - if not silent: - print(f"Running command: {join_command(command)}") - if not dry: - run_command(command, capture=capture) - - def validate_subcommand( commands: Sequence[str], workflows: Sequence[str], subcommand: str ) -> None: @@ -307,110 +255,3 @@ def validate_subcommand( ". ".join(help_msg), exits=1, ) - - -def check_rerun( - project_dir: Path, - command: Dict[str, Any], - *, - check_spacy_version: bool = True, - check_spacy_commit: bool = False, -) -> bool: - """Check if a command should be rerun because its settings or inputs/outputs - changed. - - project_dir (Path): The current project directory. - command (Dict[str, Any]): The command, as defined in the project.yml. - strict_version (bool): - RETURNS (bool): Whether to re-run the command. 
- """ - # Always rerun if no-skip is set - if command.get("no_skip", False): - return True - lock_path = project_dir / PROJECT_LOCK - if not lock_path.exists(): # We don't have a lockfile, run command - return True - data = srsly.read_yaml(lock_path) - if command["name"] not in data: # We don't have info about this command - return True - entry = data[command["name"]] - # Always run commands with no outputs (otherwise they'd always be skipped) - if not entry.get("outs", []): - return True - # Always rerun if spaCy version or commit hash changed - spacy_v = entry.get("spacy_version") - commit = entry.get("spacy_git_version") - if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__): - info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)" - msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}") - return True - if check_spacy_commit and commit != GIT_VERSION: - info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)" - msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}") - return True - # If the entry in the lockfile matches the lockfile entry that would be - # generated from the current command, we don't rerun because it means that - # all inputs/outputs, hashes and scripts are the same and nothing changed - lock_entry = get_lock_entry(project_dir, command) - exclude = ["spacy_version", "spacy_git_version"] - return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude) - - -def update_lockfile( - project_dir: Path, - command: Dict[str, Any], -) -> None: - """Update the lockfile after running a command. Will create a lockfile if - it doesn't yet exist and will add an entry for the current command, its - script and dependencies/outputs. - - project_dir (Path): The current project directory. - command (Dict[str, Any]): The command, as defined in the project.yml. - """ - lock_path = project_dir / PROJECT_LOCK - if not lock_path.exists(): - srsly.write_yaml(lock_path, {}) - data = {} - else: - data = srsly.read_yaml(lock_path) - data[command["name"]] = get_lock_entry(project_dir, command) - srsly.write_yaml(lock_path, data) - - -def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]: - """Get a lockfile entry for a given command. An entry includes the command, - the script (command steps) and a list of dependencies and outputs with - their paths and file hashes, if available. The format is based on the - dvc.lock files, to keep things consistent. - - project_dir (Path): The current project directory. - command (Dict[str, Any]): The command, as defined in the project.yml. - RETURNS (Dict[str, Any]): The lockfile entry. - """ - deps = get_fileinfo(project_dir, command.get("deps", [])) - outs = get_fileinfo(project_dir, command.get("outputs", [])) - outs_nc = get_fileinfo(project_dir, command.get("outputs_no_cache", [])) - return { - "cmd": f"{COMMAND} run {command['name']}", - "script": command["script"], - "deps": deps, - "outs": [*outs, *outs_nc], - "spacy_version": about.__version__, - "spacy_git_version": GIT_VERSION, - } - - -def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, Optional[str]]]: - """Generate the file information for a list of paths (dependencies, outputs). - Includes the file path and the file's checksum. - - project_dir (Path): The current project directory. - paths (List[str]): The file paths. - RETURNS (List[Dict[str, str]]): The lockfile entry for a file. 
- """ - data = [] - for path in paths: - file_path = project_dir / path - md5 = get_checksum(file_path) if file_path.exists() else None - data.append({"path": path, "md5": md5}) - return data diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 50ba5cecc..0fa6f5670 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -19,7 +19,6 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config from spacy.cli.package import get_third_party_dependencies from spacy.cli.package import _is_permitted_package_name from spacy.cli.validate import get_model_pkgs -from spacy.cli.project.run import project_run from spacy.lang.en import English from spacy.lang.nl import Dutch from spacy.language import Language @@ -422,170 +421,6 @@ def test_project_config_interpolation(int_value): substitute_project_variables(project) -def test_project_config_multiprocessing_good_case(): - project = { - "commands": [ - {"name": "command1", "script": ["echo", "command1"]}, - {"name": "command2", "script": ["echo", "command2"]}, - {"name": "command3", "script": ["echo", "command3"]}, - ], - "workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]}, - } - with make_tempdir() as d: - srsly.write_yaml(d / "project.yml", project) - load_project_config(d) - - -@pytest.mark.parametrize( - "workflows", - [ - {"all": ["command1", {"parallel": ["command2"]}, "command3"]}, - {"all": ["command1", {"parallel": ["command2", "command4"]}]}, - {"all": ["command1", {"parallel": ["command2", "command2"]}]}, - {"all": ["command1", {"serial": ["command2", "command3"]}]}, - ], -) -def test_project_config_multiprocessing_bad_case_workflows(workflows): - project = { - "commands": [ - {"name": "command1", "script": ["echo", "command1"]}, - {"name": "command2", "script": ["echo", "command2"]}, - {"name": "command3", "script": ["echo", "command3"]}, - ], - "workflows": workflows, - } - with make_tempdir() as d: - srsly.write_yaml(d / "project.yml", project) - with pytest.raises(SystemExit): - load_project_config(d) - - -@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1]) -def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes): - with make_tempdir() as d: - project = { - "max_parallel_processes": max_parallel_processes, - "commands": [ - { - "name": "commandA", - "script": [" ".join(("touch", os.sep.join((str(d), "A"))))], - }, - { - "name": "commandB", - "script": [" ".join(("touch", os.sep.join((str(d), "B"))))], - }, - { - "name": "commandC", - "script": [" ".join(("touch", os.sep.join((str(d), "C"))))], - }, - ], - "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]}, - } - with make_tempdir() as d: - srsly.write_yaml(d / "project.yml", project) - with pytest.raises(SystemExit): - load_project_config(d) - - -def test_project_run_multiprocessing_good_case(): - with make_tempdir() as d: - - pscript = """ -import sys, os -from time import sleep - -_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv -with open(os.sep.join((d_path, out_filename)), 'w') as out_file: - out_file.write("") -while True: - if os.path.exists(os.sep.join((d_path, in_filename))): - break - sleep(0.1) -if not os.path.exists(os.sep.join((d_path, first_flag_filename))): - with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file: - first_flag_file.write("") -else: # should never happen because of skipping - with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file: - 
second_flag_file.write("") - """ - - pscript_loc = os.sep.join((str(d), "pscript.py")) - with open(pscript_loc, "w") as pscript_file: - pscript_file.write(pscript) - os.chmod(pscript_loc, 0o777) - project = { - "commands": [ - { - "name": "commandA", - "script": [ - " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d")) - ], - "outputs": [os.sep.join((str(d), "f"))], - }, - { - "name": "commandB", - "script": [ - " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f")) - ], - "outputs": [os.sep.join((str(d), "e"))], - }, - ], - "workflows": { - "all": [ - {"parallel": ["commandA", "commandB"]}, - {"parallel": ["commandB", "commandA"]}, - ] - }, - } - srsly.write_yaml(d / "project.yml", project) - load_project_config(d) - project_run(d, "all") - assert os.path.exists(os.sep.join((str(d), "c"))) - assert os.path.exists(os.sep.join((str(d), "e"))) - assert not os.path.exists(os.sep.join((str(d), "d"))) - assert not os.path.exists(os.sep.join((str(d), "f"))) - - -@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6]) -def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes): - with make_tempdir() as d: - - project = { - "max_parallel_processes": max_parallel_processes, - "commands": [ - { - "name": "commandA", - "script": [" ".join(("touch", os.sep.join((str(d), "A"))))], - }, - { - "name": "commandB", - "script": [" ".join(("touch", os.sep.join((str(d), "B"))))], - }, - { - "name": "commandC", - "script": [" ".join(("touch", os.sep.join((str(d), "C"))))], - }, - { - "name": "commandD", - "script": [" ".join(("touch", os.sep.join((str(d), "D"))))], - }, - { - "name": "commandE", - "script": [" ".join(("touch", os.sep.join((str(d), "E"))))], - }, - ], - "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC", "commandD", "commandE"]}]}, - } - srsly.write_yaml(d / "project.yml", project) - load_project_config(d) - project_run(d, "all") - assert os.path.exists(os.sep.join((str(d), "A"))) - assert os.path.exists(os.sep.join((str(d), "B"))) - assert os.path.exists(os.sep.join((str(d), "C"))) - assert os.path.exists(os.sep.join((str(d), "D"))) - assert os.path.exists(os.sep.join((str(d), "E"))) - - @pytest.mark.parametrize( "greeting", [342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)], diff --git a/spacy/tests/test_parallel.py b/spacy/tests/test_parallel.py new file mode 100644 index 000000000..68bd1a027 --- /dev/null +++ b/spacy/tests/test_parallel.py @@ -0,0 +1,183 @@ +import os + +import pytest +import srsly +from spacy.cli._util import load_project_config +from spacy.cli.project.run import project_run +from .util import make_tempdir + + +def test_project_config_multiprocessing_good_case(): + project = { + "commands": [ + {"name": "command1", "script": ["echo", "command1"]}, + {"name": "command2", "script": ["echo", "command2"]}, + {"name": "command3", "script": ["echo", "command3"]}, + ], + "workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]}, + } + with make_tempdir() as d: + srsly.write_yaml(d / "project.yml", project) + load_project_config(d) + + +@pytest.mark.parametrize( + "workflows", + [ + {"all": ["command1", {"parallel": ["command2"]}, "command3"]}, + {"all": ["command1", {"parallel": ["command2", "command4"]}]}, + {"all": ["command1", {"parallel": ["command2", "command2"]}]}, + {"all": ["command1", {"serial": ["command2", "command3"]}]}, + ], +) +def test_project_config_multiprocessing_bad_case_workflows(workflows): + project = { + "commands": [ + {"name": "command1", 
"script": ["echo", "command1"]}, + {"name": "command2", "script": ["echo", "command2"]}, + {"name": "command3", "script": ["echo", "command3"]}, + ], + "workflows": workflows, + } + with make_tempdir() as d: + srsly.write_yaml(d / "project.yml", project) + with pytest.raises(SystemExit): + load_project_config(d) + + +@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1]) +def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes): + with make_tempdir() as d: + project = { + "max_parallel_processes": max_parallel_processes, + "commands": [ + { + "name": "commandA", + "script": [" ".join(("touch", os.sep.join((str(d), "A"))))], + }, + { + "name": "commandB", + "script": [" ".join(("touch", os.sep.join((str(d), "B"))))], + }, + { + "name": "commandC", + "script": [" ".join(("touch", os.sep.join((str(d), "C"))))], + }, + ], + "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]}, + } + with make_tempdir() as d: + srsly.write_yaml(d / "project.yml", project) + with pytest.raises(SystemExit): + load_project_config(d) + + +def test_project_run_multiprocessing_good_case(): + with make_tempdir() as d: + + pscript = """ +import sys, os +from time import sleep + +_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv +with open(os.sep.join((d_path, out_filename)), 'w') as out_file: + out_file.write("") +while True: + if os.path.exists(os.sep.join((d_path, in_filename))): + break + sleep(0.1) +if not os.path.exists(os.sep.join((d_path, first_flag_filename))): + with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file: + first_flag_file.write("") +else: # should never happen because of skipping + with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file: + second_flag_file.write("") + """ + + pscript_loc = os.sep.join((str(d), "pscript.py")) + with open(pscript_loc, "w") as pscript_file: + pscript_file.write(pscript) + os.chmod(pscript_loc, 0o777) + project = { + "commands": [ + { + "name": "commandA", + "script": [ + " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d")) + ], + "outputs": [os.sep.join((str(d), "f"))], + }, + { + "name": "commandB", + "script": [ + " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f")) + ], + "outputs": [os.sep.join((str(d), "e"))], + }, + ], + "workflows": { + "all": [ + {"parallel": ["commandA", "commandB"]}, + {"parallel": ["commandB", "commandA"]}, + ] + }, + } + srsly.write_yaml(d / "project.yml", project) + load_project_config(d) + project_run(d, "all") + assert os.path.exists(os.sep.join((str(d), "c"))) + assert os.path.exists(os.sep.join((str(d), "e"))) + assert not os.path.exists(os.sep.join((str(d), "d"))) + assert not os.path.exists(os.sep.join((str(d), "f"))) + + +@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6]) +def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes): + with make_tempdir() as d: + + project = { + "max_parallel_processes": max_parallel_processes, + "commands": [ + { + "name": "commandA", + "script": [" ".join(("touch", os.sep.join((str(d), "A"))))], + }, + { + "name": "commandB", + "script": [" ".join(("touch", os.sep.join((str(d), "B"))))], + }, + { + "name": "commandC", + "script": [" ".join(("touch", os.sep.join((str(d), "C"))))], + }, + { + "name": "commandD", + "script": [" ".join(("touch", os.sep.join((str(d), "D"))))], + }, + { + "name": "commandE", + "script": [" ".join(("touch", os.sep.join((str(d), "E"))))], + }, + ], + "workflows": { + 
"all": [ + { + "parallel": [ + "commandA", + "commandB", + "commandC", + "commandD", + "commandE", + ] + } + ] + }, + } + srsly.write_yaml(d / "project.yml", project) + load_project_config(d) + project_run(d, "all") + assert os.path.exists(os.sep.join((str(d), "A"))) + assert os.path.exists(os.sep.join((str(d), "B"))) + assert os.path.exists(os.sep.join((str(d), "C"))) + assert os.path.exists(os.sep.join((str(d), "D"))) + assert os.path.exists(os.sep.join((str(d), "E")))