From 12e86004c88052cf03cc9d2fd9c8f076d9a569f1 Mon Sep 17 00:00:00 2001
From: richardpaulhudson
Date: Mon, 9 May 2022 17:30:26 +0200
Subject: [PATCH] Basic multiprocessing functionality

---
 spacy/cli/_util.py       |   8 ++-
 spacy/cli/project/run.py | 134 +++++++++++++++++++++++++--------------
 spacy/tests/test_cli.py  |  56 ++++++++++++++++
 3 files changed, 148 insertions(+), 50 deletions(-)

diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index ce9c2a5d3..54a868953 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -249,10 +249,16 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
             workflow_list = cast(List[str], step_or_list)
             if len(workflow_list) < 2:
                 msg.fail(
-                    f"Invalid multiprocessing group within '{workflow_name}'",
+                    f"Invalid multiprocessing group within '{workflow_name}'.",
                     f"A multiprocessing group must reference at least two commands.",
                     exits=1,
                 )
+            if len(workflow_list) != len(set(workflow_list)):
+                msg.fail(
+                    f"A multiprocessing group within '{workflow_name}' contains a command more than once.",
+                    f"This is not permitted because it would then be impossible to determine whether the command needs to be rerun.",
+                    exits=1,
+                )
             for step in workflow_list:
                 verify_workflow_step(step)
 
diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py
index 734803bc4..24b827083 100644
--- a/spacy/cli/project/run.py
+++ b/spacy/cli/project/run.py
@@ -1,5 +1,6 @@
-from typing import Optional, List, Dict, Sequence, Any, Iterable
+from typing import Optional, List, Dict, Sequence, Any, Iterable, cast
 from pathlib import Path
+from multiprocessing import Process, Lock
 from wasabi import msg
 from wasabi.util import locale_escape
 import sys
@@ -50,6 +51,7 @@ def project_run(
     force: bool = False,
     dry: bool = False,
     capture: bool = False,
+    mult_group_mutex: Optional[Lock] = None,
 ) -> None:
     """Run a named script defined in the project.yml. If the script is part
     of the default pipeline (defined in the "run" section), DVC is used to
@@ -67,21 +69,40 @@ def project_run(
     when you want to turn over execution to the command, and capture=True
     when you want to run the command more like a function.
     """
+    if mult_group_mutex is None:
+        mult_group_mutex = Lock()
     config = load_project_config(project_dir, overrides=overrides)
     commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
     workflows = config.get("workflows", {})
     validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
     if subcommand in workflows:
         msg.info(f"Running workflow '{subcommand}'")
-        for cmd in workflows[subcommand]:
-            project_run(
-                project_dir,
-                cmd,
-                overrides=overrides,
-                force=force,
-                dry=dry,
-                capture=capture,
-            )
+        for cmd_or_multiprocessing_group in workflows[subcommand]:
+            if isinstance(cmd_or_multiprocessing_group, str):
+                project_run(
+                    project_dir,
+                    cmd_or_multiprocessing_group,
+                    overrides=overrides,
+                    force=force,
+                    dry=dry,
+                    capture=capture,
+                    mult_group_mutex=mult_group_mutex,
+                )
+            else:
+                processes = [Process(
+                    target=project_run,
+                    args=(project_dir, cmd),
+                    kwargs={
+                        "overrides": overrides,
+                        "force": force,
+                        "dry": dry,
+                        "capture": capture,
+                        "mult_group_mutex": mult_group_mutex,
+                    }) for cmd in cmd_or_multiprocessing_group]
+                for process in processes:
+                    process.start()
+                for process in processes:
+                    process.join()
     else:
         cmd = commands[subcommand]
         for dep in cmd.get("deps", []):
@@ -93,13 +114,18 @@ def project_run(
         check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
         with working_dir(project_dir) as current_dir:
             msg.divider(subcommand)
-            rerun = check_rerun(current_dir, cmd, check_spacy_commit=check_spacy_commit)
+            rerun = check_rerun(
+                current_dir,
+                cmd,
+                check_spacy_commit=check_spacy_commit,
+                mult_group_mutex=mult_group_mutex,
+            )
             if not rerun and not force:
                 msg.info(f"Skipping '{cmd['name']}': nothing changed")
             else:
                 run_commands(cmd["script"], dry=dry, capture=capture)
                 if not dry:
-                    update_lockfile(current_dir, cmd)
+                    update_lockfile(current_dir, cmd, mult_group_mutex=mult_group_mutex)
 
 
 def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
@@ -157,7 +183,7 @@ def run_commands(
 
     commands (List[str]): The string commands.
     silent (bool): Don't print the commands.
-    dry (bool): Perform a dry run and don't execut anything.
+    dry (bool): Perform a dry run and don't execute anything.
     capture (bool): Whether to capture the output and errors of individual commands.
         If False, the stdout and stderr will not be redirected, and if there's an error,
         sys.exit will be called with the return code. You should use capture=False
@@ -212,6 +238,7 @@ def check_rerun(
     *,
     check_spacy_version: bool = True,
     check_spacy_commit: bool = False,
+    mult_group_mutex: Lock,
 ) -> bool:
     """Check if a command should be rerun because its settings or inputs/outputs
     changed.
@@ -224,51 +251,60 @@ def check_rerun(
     # Always rerun if no-skip is set
     if command.get("no_skip", False):
         return True
-    lock_path = project_dir / PROJECT_LOCK
-    if not lock_path.exists():  # We don't have a lockfile, run command
-        return True
-    data = srsly.read_yaml(lock_path)
-    if command["name"] not in data:  # We don't have info about this command
-        return True
-    entry = data[command["name"]]
-    # Always run commands with no outputs (otherwise they'd always be skipped)
-    if not entry.get("outs", []):
-        return True
-    # Always rerun if spaCy version or commit hash changed
-    spacy_v = entry.get("spacy_version")
-    commit = entry.get("spacy_git_version")
-    if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
-        info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
-        msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
-        return True
-    if check_spacy_commit and commit != GIT_VERSION:
-        info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
-        msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
-        return True
-    # If the entry in the lockfile matches the lockfile entry that would be
-    # generated from the current command, we don't rerun because it means that
-    # all inputs/outputs, hashes and scripts are the same and nothing changed
-    lock_entry = get_lock_entry(project_dir, command)
-    exclude = ["spacy_version", "spacy_git_version"]
-    return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
+    with mult_group_mutex:
+        lock_path = project_dir / PROJECT_LOCK
+        if not lock_path.exists():  # We don't have a lockfile, run command
+            return True
+        data = srsly.read_yaml(lock_path)
+        if command["name"] not in data:  # We don't have info about this command
+            return True
+        entry = data[command["name"]]
+        # Always run commands with no outputs (otherwise they'd always be skipped)
+        if not entry.get("outs", []):
+            return True
+        # Always rerun if spaCy version or commit hash changed
+        spacy_v = entry.get("spacy_version")
+        commit = entry.get("spacy_git_version")
+        if check_spacy_version and not is_minor_version_match(
+            spacy_v, about.__version__
+        ):
+            info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
+            msg.info(
+                f"Re-running '{command['name']}': spaCy minor version changed {info}"
+            )
+            return True
+        if check_spacy_commit and commit != GIT_VERSION:
+            info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
+            msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
+            return True
+        # If the entry in the lockfile matches the lockfile entry that would be
+        # generated from the current command, we don't rerun because it means that
+        # all inputs/outputs, hashes and scripts are the same and nothing changed
+        lock_entry = get_lock_entry(project_dir, command)
+        exclude = ["spacy_version", "spacy_git_version"]
+        return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
 
 
-def update_lockfile(project_dir: Path, command: Dict[str, Any]) -> None:
+def update_lockfile(
+    project_dir: Path, command: Dict[str, Any], mult_group_mutex: Lock
+) -> None:
     """Update the lockfile after running a command. Will create a lockfile if
     it doesn't yet exist and will add an entry for the current command, its
     script and dependencies/outputs.
 
     project_dir (Path): The current project directory.
     command (Dict[str, Any]): The command, as defined in the project.yml.
+    mult_group_mutex (Lock): The mutex preventing concurrent lockfile writes.
     """
-    lock_path = project_dir / PROJECT_LOCK
-    if not lock_path.exists():
-        srsly.write_yaml(lock_path, {})
-        data = {}
-    else:
-        data = srsly.read_yaml(lock_path)
-    data[command["name"]] = get_lock_entry(project_dir, command)
-    srsly.write_yaml(lock_path, data)
+    with mult_group_mutex:
+        lock_path = project_dir / PROJECT_LOCK
+        if not lock_path.exists():
+            srsly.write_yaml(lock_path, {})
+            data = {}
+        else:
+            data = srsly.read_yaml(lock_path)
+        data[command["name"]] = get_lock_entry(project_dir, command)
+        srsly.write_yaml(lock_path, data)
 
 
 def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index 86583b1ab..310bc97b8 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -19,6 +19,7 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
 from spacy.cli.package import get_third_party_dependencies
 from spacy.cli.package import _is_permitted_package_name
 from spacy.cli.validate import get_model_pkgs
+from spacy.cli.project.run import project_run
 from spacy.lang.en import English
 from spacy.lang.nl import Dutch
 from spacy.language import Language
@@ -440,6 +441,7 @@ def test_project_config_multiprocessing_good_case():
     [
         {"all": ["command1", ["command2"], "command3"]},
         {"all": ["command1", ["command2", "command4"]]},
+        {"all": ["command1", ["command2", "command2"]]},
     ],
 )
 def test_project_config_multiprocessing_bad_case(workflows):
@@ -457,6 +459,60 @@ def test_project_config_multiprocessing_bad_case(workflows):
             load_project_config(d)
 
 
+def test_project_run_multiprocessing_good_case():
+    with make_tempdir() as d:
+
+        pscript = """
+import sys, os
+from time import sleep
+
+_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
+with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
+    out_file.write("")
+while True:
+    if os.path.exists(os.sep.join((d_path, in_filename))):
+        break
+    sleep(0.1)
+if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
+    with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
+        first_flag_file.write("")
+else: # should never happen because of skipping
+    with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
+        second_flag_file.write("")
+    """
+
+        pscript_loc = os.sep.join((str(d), "pscript.py"))
+        with open(pscript_loc, "w") as pscript_file:
+            pscript_file.write(pscript)
+        os.chmod(pscript_loc, 0o777)
+        project = {
+            "commands": [
+                {
+                    "name": "commandA",
+                    "script": [
+                        " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
+                    ],
+                    "outputs": [" ".join((str(d), "c"))],
+                },
+                {
+                    "name": "commandB",
+                    "script": [
+                        " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
+                    ],
+                    "outputs": [" ".join((str(d), "e"))],
+                },
+            ],
+            "workflows": {"all": [["commandA", "commandB"], ["commandA", "commandB"]]},
+        }
+        srsly.write_yaml(d / "project.yml", project)
+        load_project_config(d)
+        project_run(d, "all")
+        assert os.path.exists(os.sep.join((str(d), "c")))
+        assert os.path.exists(os.sep.join((str(d), "e")))
+        assert not os.path.exists(os.sep.join((str(d), "d")))
+        assert not os.path.exists(os.sep.join((str(d), "f")))
+
+
 @pytest.mark.parametrize(
     "greeting",
     [342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
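Usage sketch (not part of the diff above): with this patch, a workflow entry in project.yml may nest a list of command names to mark a multiprocessing group; the grouped commands run in parallel processes, while plain string entries still run sequentially. The snippet below mirrors test_project_run_multiprocessing_good_case and drives the feature from Python; the command names and scripts are illustrative assumptions, not anything defined in this patch.

# Hypothetical end-to-end example; assumes a project directory that contains
# the referenced scripts. Only the nested list under "workflows" is new.
from pathlib import Path

import srsly
from spacy.cli._util import load_project_config
from spacy.cli.project.run import project_run

project_dir = Path("example_project")  # assumed to exist
project = {
    "commands": [
        {"name": "preprocess-a", "script": ["python scripts/preprocess_a.py"]},
        {"name": "preprocess-b", "script": ["python scripts/preprocess_b.py"]},
        {"name": "train", "script": ["python scripts/train.py"]},
    ],
    "workflows": {
        # ["preprocess-a", "preprocess-b"] is a multiprocessing group: both
        # commands run in parallel; "train" only starts once both have finished.
        "all": [["preprocess-a", "preprocess-b"], "train"]
    },
}
srsly.write_yaml(project_dir / "project.yml", project)
load_project_config(project_dir)  # validates the multiprocessing groups
project_run(project_dir, "all")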
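On the locking scheme: project_run() creates a single multiprocessing Lock at the top-level call and threads it through every recursive and child-process invocation, and check_rerun()/update_lockfile() hold that mutex around their read-modify-write of the shared PROJECT_LOCK file so that parallel commands cannot overwrite each other's entries. Below is a standalone sketch of that pattern using plain multiprocessing and srsly, with no spaCy involved; the file and entry names are made up.

# Minimal illustration of the shared-mutex lockfile pattern used by the patch.
from multiprocessing import Lock, Process
from pathlib import Path

import srsly


def add_entry(lock_path, name, mutex):
    # Like update_lockfile(): read the shared YAML file, add this process's
    # entry and write it back, all while holding the shared mutex.
    with mutex:
        data = srsly.read_yaml(lock_path) if lock_path.exists() else {}
        data[name] = {"done": True}
        srsly.write_yaml(lock_path, data)


if __name__ == "__main__":
    lock_path = Path("example.lock")
    mutex = Lock()  # created once in the parent, as in project_run()
    processes = [
        Process(target=add_entry, args=(lock_path, name, mutex))
        for name in ("commandA", "commandB")
    ]
    for process in processes:
        process.start()
    for process in processes:
        process.join()
    # Both entries survive because the read-modify-write cycles never interleave.
    assert set(srsly.read_yaml(lock_path)) == {"commandA", "commandB"}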