Basic multiprocessing functionality

This commit is contained in:
richardpaulhudson 2022-05-09 17:30:26 +02:00
parent 8d08a68174
commit 12e86004c8
3 changed files with 148 additions and 50 deletions

View File

@ -249,10 +249,16 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
workflow_list = cast(List[str], step_or_list)
if len(workflow_list) < 2:
msg.fail(
f"Invalid multiprocessing group within '{workflow_name}'",
f"Invalid multiprocessing group within '{workflow_name}'.",
f"A multiprocessing group must reference at least two commands.",
exits=1,
)
if len(workflow_list) != len(set(workflow_list)):
msg.fail(
f"A multiprocessing group within '{workflow_name}' contains a command more than once.",
f"This is not permitted because it is then not possible to determine when to rerun.",
exits=1,
)
for step in workflow_list:
verify_workflow_step(step)

View File

@ -1,5 +1,6 @@
from typing import Optional, List, Dict, Sequence, Any, Iterable
from typing import Optional, List, Dict, Sequence, Any, Iterable, cast
from pathlib import Path
from multiprocessing import Process, Lock
from wasabi import msg
from wasabi.util import locale_escape
import sys
@ -50,6 +51,7 @@ def project_run(
force: bool = False,
dry: bool = False,
capture: bool = False,
mult_group_mutex: Optional[Lock] = None,
) -> None:
"""Run a named script defined in the project.yml. If the script is part
of the default pipeline (defined in the "run" section), DVC is used to
@ -67,21 +69,40 @@ def project_run(
when you want to turn over execution to the command, and capture=True
when you want to run the command more like a function.
"""
if mult_group_mutex is None:
mult_group_mutex = Lock()
config = load_project_config(project_dir, overrides=overrides)
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
workflows = config.get("workflows", {})
validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
if subcommand in workflows:
msg.info(f"Running workflow '{subcommand}'")
for cmd in workflows[subcommand]:
project_run(
project_dir,
cmd,
overrides=overrides,
force=force,
dry=dry,
capture=capture,
)
for cmdOrMultiprocessingGroup in workflows[subcommand]:
if isinstance(cmdOrMultiprocessingGroup, str):
project_run(
project_dir,
cmdOrMultiprocessingGroup,
overrides=overrides,
force=force,
dry=dry,
capture=capture,
mult_group_mutex=mult_group_mutex,
)
else:
processes = [Process(
target=project_run,
args=(project_dir, cmd),
kwargs={
"overrides": overrides,
"force": force,
"dry": dry,
"capture": capture,
"mult_group_mutex": mult_group_mutex,
}) for cmd in cmdOrMultiprocessingGroup]
for process in processes:
process.start()
for process in processes:
process.join()
else:
cmd = commands[subcommand]
for dep in cmd.get("deps", []):
@ -93,13 +114,18 @@ def project_run(
check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
with working_dir(project_dir) as current_dir:
msg.divider(subcommand)
rerun = check_rerun(current_dir, cmd, check_spacy_commit=check_spacy_commit)
rerun = check_rerun(
current_dir,
cmd,
check_spacy_commit=check_spacy_commit,
mult_group_mutex=mult_group_mutex,
)
if not rerun and not force:
msg.info(f"Skipping '{cmd['name']}': nothing changed")
else:
run_commands(cmd["script"], dry=dry, capture=capture)
if not dry:
update_lockfile(current_dir, cmd)
update_lockfile(current_dir, cmd, mult_group_mutex=mult_group_mutex)
def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
@ -157,7 +183,7 @@ def run_commands(
commands (List[str]): The string commands.
silent (bool): Don't print the commands.
dry (bool): Perform a dry run and don't execut anything.
dry (bool): Perform a dry run and don't execute anything.
capture (bool): Whether to capture the output and errors of individual commands.
If False, the stdout and stderr will not be redirected, and if there's an error,
sys.exit will be called with the return code. You should use capture=False
@ -212,6 +238,7 @@ def check_rerun(
*,
check_spacy_version: bool = True,
check_spacy_commit: bool = False,
mult_group_mutex: Lock,
) -> bool:
"""Check if a command should be rerun because its settings or inputs/outputs
changed.
@ -224,51 +251,60 @@ def check_rerun(
# Always rerun if no-skip is set
if command.get("no_skip", False):
return True
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists(): # We don't have a lockfile, run command
return True
data = srsly.read_yaml(lock_path)
if command["name"] not in data: # We don't have info about this command
return True
entry = data[command["name"]]
# Always run commands with no outputs (otherwise they'd always be skipped)
if not entry.get("outs", []):
return True
# Always rerun if spaCy version or commit hash changed
spacy_v = entry.get("spacy_version")
commit = entry.get("spacy_git_version")
if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
return True
if check_spacy_commit and commit != GIT_VERSION:
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
return True
# If the entry in the lockfile matches the lockfile entry that would be
# generated from the current command, we don't rerun because it means that
# all inputs/outputs, hashes and scripts are the same and nothing changed
lock_entry = get_lock_entry(project_dir, command)
exclude = ["spacy_version", "spacy_git_version"]
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
with mult_group_mutex:
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists(): # We don't have a lockfile, run command
return True
data = srsly.read_yaml(lock_path)
if command["name"] not in data: # We don't have info about this command
return True
entry = data[command["name"]]
# Always run commands with no outputs (otherwise they'd always be skipped)
if not entry.get("outs", []):
return True
# Always rerun if spaCy version or commit hash changed
spacy_v = entry.get("spacy_version")
commit = entry.get("spacy_git_version")
if check_spacy_version and not is_minor_version_match(
spacy_v, about.__version__
):
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
msg.info(
f"Re-running '{command['name']}': spaCy minor version changed {info}"
)
return True
if check_spacy_commit and commit != GIT_VERSION:
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
return True
# If the entry in the lockfile matches the lockfile entry that would be
# generated from the current command, we don't rerun because it means that
# all inputs/outputs, hashes and scripts are the same and nothing changed
lock_entry = get_lock_entry(project_dir, command)
exclude = ["spacy_version", "spacy_git_version"]
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
def update_lockfile(project_dir: Path, command: Dict[str, Any]) -> None:
def update_lockfile(
project_dir: Path, command: Dict[str, Any], mult_group_mutex: Lock
) -> None:
"""Update the lockfile after running a command. Will create a lockfile if
it doesn't yet exist and will add an entry for the current command, its
script and dependencies/outputs.
project_dir (Path): The current project directory.
command (Dict[str, Any]): The command, as defined in the project.yml.
mult_group_mutex: the mutex preventing concurrent writes
"""
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists():
srsly.write_yaml(lock_path, {})
data = {}
else:
data = srsly.read_yaml(lock_path)
data[command["name"]] = get_lock_entry(project_dir, command)
srsly.write_yaml(lock_path, data)
with mult_group_mutex:
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists():
srsly.write_yaml(lock_path, {})
data = {}
else:
data = srsly.read_yaml(lock_path)
data[command["name"]] = get_lock_entry(project_dir, command)
srsly.write_yaml(lock_path, data)
def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:

View File

@ -19,6 +19,7 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.package import get_third_party_dependencies
from spacy.cli.package import _is_permitted_package_name
from spacy.cli.validate import get_model_pkgs
from spacy.cli.project.run import project_run
from spacy.lang.en import English
from spacy.lang.nl import Dutch
from spacy.language import Language
@ -440,6 +441,7 @@ def test_project_config_multiprocessing_good_case():
[
{"all": ["command1", ["command2"], "command3"]},
{"all": ["command1", ["command2", "command4"]]},
{"all": ["command1", ["command2", "command2"]]},
],
)
def test_project_config_multiprocessing_bad_case(workflows):
@ -457,6 +459,60 @@ def test_project_config_multiprocessing_bad_case(workflows):
load_project_config(d)
def test_project_run_multiprocessing_good_case():
with make_tempdir() as d:
pscript = """
import sys, os
from time import sleep
_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
out_file.write("")
while True:
if os.path.exists(os.sep.join((d_path, in_filename))):
break
sleep(0.1)
if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
first_flag_file.write("")
else: # should never happen because of skipping
with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
second_flag_file.write("")
"""
pscript_loc = os.sep.join((str(d), "pscript.py"))
with open(pscript_loc, "w") as pscript_file:
pscript_file.write(pscript)
os.chmod(pscript_loc, 0o777)
project = {
"commands": [
{
"name": "commandA",
"script": [
" ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
],
"outputs": [" ".join((str(d), "c"))],
},
{
"name": "commandB",
"script": [
" ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
],
"outputs": [" ".join((str(d), "e"))],
},
],
"workflows": {"all": [["commandA", "commandB"], ["commandA", "commandB"]]},
}
srsly.write_yaml(d / "project.yml", project)
load_project_config(d)
project_run(d, "all")
assert os.path.exists(os.sep.join((str(d), "c")))
assert os.path.exists(os.sep.join((str(d), "e")))
assert not os.path.exists(os.sep.join((str(d), "d")))
assert not os.path.exists(os.sep.join((str(d), "f")))
@pytest.mark.parametrize(
"greeting",
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],