First draft of new implementation (incomplete, doesn't run yet)

This commit is contained in:
richardpaulhudson 2022-07-14 22:14:57 +02:00
parent e6a11b58ca
commit 9d7a79e305


@@ -1,9 +1,12 @@
from typing import Optional, List, Dict, Sequence, Any, Iterable, Union, Tuple
from typing import Optional, List, Dict, Sequence, Any, Iterable, Literal, Tuple, Union
from time import time
from os import kill, environ
from signal import SIGTERM
from subprocess import Popen, TimeoutExpired
from tempfile import TemporaryFile
from pathlib import Path
from multiprocessing import Process, Lock, Queue
from multiprocessing.synchronize import Lock as Lock_t
from wasabi import msg
from wasabi.util import locale_escape
from multiprocessing import Process, Queue
from queue import Empty
from wasabi import msg, Printer
from wasabi.util import locale_escape, supports_ansi
import sys
import srsly
import typer
@@ -15,6 +18,7 @@ from ...util import SimpleFrozenList, is_minor_version_match, ENV_VARS
from ...util import check_bool_env_var, SimpleFrozenDict
from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
from .._util import get_checksum, project_cli, Arg, Opt, COMMAND, parse_config_overrides
from ...errors import Errors
@project_cli.command(
@@ -44,6 +48,19 @@ def project_run_cli(
project_run(project_dir, subcommand, overrides=overrides, force=force, dry=dry)
class MultiprocessingCommandInfo:
def __init__(self, cmd_name: str, cmd: Dict[str, Any]):
self.cmd_name = cmd_name
self.cmd = cmd
self.len_os_cmds = len(cmd["script"])
self.status: Literal[
"pending", "not rerun", "running", "hung", "succeeded", "failed", "killed"
] = "pending"
self.pid: Optional[int] = None
self.last_status_time: Optional[int] = None
self.os_cmd_index: Optional[int] = None
self.output: Optional[str] = None
self.rc: Optional[int] = None
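To make the intended usage concrete, here is a rough sketch of how one of these objects would be built from a project.yml command entry; the field names follow the existing command schema ("name", "script", "deps"), but the concrete values are invented for illustration:

# Sketch only: "example_cmd" mirrors one entry from the "commands" section of a
# project.yml, as returned by load_project_config (values are invented).
example_cmd = {
    "name": "preprocess",
    "script": ["python scripts/convert.py", "python scripts/split.py"],
    "deps": ["assets/raw.jsonl"],
}
info = MultiprocessingCommandInfo(example_cmd["name"], example_cmd)
assert info.status == "pending"
assert info.len_os_cmds == 2  # one entry per OS command in "script"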
def project_run(
project_dir: Path,
subcommand: str,
@@ -52,8 +69,6 @@ def project_run(
force: bool = False,
dry: bool = False,
capture: bool = False,
mult_group_mutex: Optional[Lock_t] = None,
completion_queue: Optional[Queue] = None,
) -> None:
"""Run a named script defined in the project.yml. If the script is part
of the default pipeline (defined in the "run" section), DVC is used to
@@ -71,12 +86,9 @@ def project_run(
when you want to turn over execution to the command, and capture=True
when you want to run the command more like a function.
"""
if mult_group_mutex is None:
mult_group_mutex = Lock()
config = load_project_config(project_dir, overrides=overrides)
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
workflows = config.get("workflows", {})
max_parallel_processes = config.get("max_parallel_processes")
validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
if subcommand in workflows:
msg.info(f"Running workflow '{subcommand}'")
@@ -89,44 +101,18 @@ def project_run(
force=force,
dry=dry,
capture=capture,
mult_group_mutex=mult_group_mutex,
)
else:
assert isinstance(workflow_item, dict)
assert len(workflow_item) == 1
steps_list = workflow_item["parallel"]
assert isinstance(steps_list[0], str)
completion_queue = Queue(len(steps_list))
processes = [
Process(
target=project_run,
args=(project_dir, cmd),
kwargs={
"overrides": overrides,
"force": force,
"dry": dry,
"capture": capture,
"mult_group_mutex": mult_group_mutex,
"completion_queue": completion_queue,
},
)
for cmd in steps_list
]
num_processes = len(processes)
if (
max_parallel_processes is not None
and max_parallel_processes < num_processes
):
num_processes = max_parallel_processes
process_iterator = iter(processes)
for _ in range(num_processes):
next(process_iterator).start()
for _ in range(len(steps_list)):
completion_queue.get()
next_process = next(process_iterator, None)
if next_process is not None:
next_process.start()
cmds = workflow_item["parallel"]
assert isinstance(cmds[0], str)
project_run_mult_group(
project_dir,
[MultiprocessingCommandInfo(cmd, commands[cmd]) for cmd in cmds],
overrides=overrides,
force=force,
)
else:
cmd = commands[subcommand]
for dep in cmd.get("deps", []):
@@ -142,16 +128,227 @@ def project_run(
current_dir,
cmd,
check_spacy_commit=check_spacy_commit,
mult_group_mutex=mult_group_mutex,
)
if not rerun and not force:
msg.info(f"Skipping '{cmd['name']}': nothing changed")
else:
run_commands(cmd["script"], dry=dry, capture=capture)
run_commands(
cmd["script"],
dry=dry,
capture=capture,
)
if not dry:
update_lockfile(current_dir, cmd, mult_group_mutex=mult_group_mutex)
if completion_queue is not None:
completion_queue.put(None)
update_lockfile(current_dir, cmd)
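For orientation, a hypothetical configuration that would reach the parallel branch above could look like the following, shown as the dict load_project_config would return; the command names, scripts and the max_parallel_processes value are invented:

# Hypothetical config dict for a project.yml with a parallel workflow step.
example_config = {
    "commands": [
        {"name": "train-a", "script": ["python scripts/train.py a"]},
        {"name": "train-b", "script": ["python scripts/train.py b"]},
        {"name": "evaluate", "script": ["python scripts/evaluate.py"]},
    ],
    "workflows": {
        # Plain strings run serially; a one-key {"parallel": [...]} dict groups
        # commands that should run as a multiprocessing group.
        "all": [{"parallel": ["train-a", "train-b"]}, "evaluate"],
    },
    # Upper bound on how many group members run at the same time.
    "max_parallel_processes": 2,
}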
# How often (in seconds) a worker reports that its current OS command is still running
MULTIPROCESSING_GROUP_STATUS_INTERVAL = 1
# How long (in seconds) the parent waits without news from a worker before marking it as hung
MULTIPROCESSING_GROUP_LIVENESS_INTERVAL = 5
def project_run_mult_group(
project_dir: Path,
cmd_infos: List[MultiprocessingCommandInfo],
*,
overrides: Dict[str, Any] = SimpleFrozenDict(),
force: bool = False,
dry: bool = False,
) -> None:
config = load_project_config(project_dir, overrides=overrides)
mult_group_status_queue = Queue()
max_parallel_processes = config.get("max_parallel_processes")
check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
DISPLAY_STATUS = sys.stdout.isatty() and supports_ansi()
with working_dir(project_dir) as current_dir:
for cmd_info in cmd_infos:
for dep in cmd_info.cmd.get("deps", []):
if not (project_dir / dep).exists():
err = f"Missing dependency specified by command '{cmd_info.cmd_name}': {dep}"
err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
err_kwargs = {"exits": 1} if not dry else {}
msg.fail(err, err_help, **err_kwargs)
if not check_rerun(
current_dir, cmd_info.cmd, check_spacy_commit=check_spacy_commit
):
cmd_info.status = "not rerun"
processes = [
Process(
target=project_run_mult_command,
args=(cmd_info,),
kwargs={
"dry": dry,
"cmd_ind": cmd_ind,
"mult_group_status_queue": mult_group_status_queue,
},
)
for cmd_ind, cmd_info in enumerate(cmd_infos)
if cmd_info.status != "not rerun"
]
num_processes = len(processes)
if (
max_parallel_processes is not None
and max_parallel_processes < num_processes
):
num_processes = max_parallel_processes
process_iterator = iter(processes)
for _ in range(num_processes):
next(process_iterator).start()
msg.divider("MULTIPROCESSING GROUP")
if not DISPLAY_STATUS:
print(
"parallel[" + ", ".join(cmd_info.name for cmd_info in cmd_infos) + "]"
)
first = True
while any(cmd_info.status in ("pending", "running") for cmd_info in cmd_infos):
try:
mess: Dict[str, Union[str, int]] = mult_group_status_queue.get(
timeout=MULTIPROCESSING_GROUP_LIVENESS_INTERVAL
)
except Empty:
for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
other_cmd_info.status = "hung"
continue
cmd_info = cmd_infos[mess["cmd_ind"]]
if mess["status"] in ("start", "alive"):
cmd_info.last_status_time = int(time())
for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
if (
other_cmd_info.last_status_time is not None
and time() - other_cmd_info.last_status_time
> MULTIPROCESSING_GROUP_LIVENESS_INTERVAL
):
other_cmd_info.status = "hung"
if mess["status"] == "start":
cmd_info.status = "running"
cmd_info.os_cmd_index = mess["os_cmd_index"]
elif mess["status"] == "succeeded":
cmd_info.status = "succeeded"
if not dry:
update_lockfile(current_dir, cmd_info.cmd)
next_process = next(process_iterator, None)
if next_process is not None:
next_process.start()
elif mess[1] == "failed":
cmd_info.status = "failed"
if any(c for c in cmd_infos if c.status in ("failed", "hung")):
for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
kill(other_cmd_info.pid, SIGTERM)
if mess["status"] in ("succeeded", "failed", "killed"):
cmd_info.output = mess["output"]
if mess["status"] != "alive" and DISPLAY_STATUS:
if first:
first = False
else:
print("\033[F" * (2 + len(cmd_infos)))
data = [
[c.cmd_name, c.status, f"{c.os_cmd_index}/{c.len_os_cmds}"]
for c in cmd_infos
]
header = ["Command", "Status", "OS Command"]
msg.table(data, header=header)
for cmd_info in cmd_infos:
if cmd_info.output is not None:
print(cmd_info.output)
rcs = [c.rc for c in cmd_infos if c.rc is not None]
group_rc = max(rcs, default=0)
if group_rc <= 0:
group_rc = min(rcs, default=0)
if group_rc != 0:
sys.exit(group_rc)
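The live status table above is redrawn in place by emitting the ANSI "cursor to previous line" escape ("\033[F") once per printed line before printing again; a standalone sketch of that trick, assuming an ANSI-capable terminal, which is what the supports_ansi() check guards against:

import sys
import time

# Minimal redraw sketch: print a block of lines, then move the cursor back up
# over the block so the next iteration overwrites it in place.
rows = 3
for frame in range(5):
    for row in range(rows):
        print(f"command {row}: frame {frame}")
    time.sleep(0.5)
    if frame < 4:
        sys.stdout.write("\033[F" * rows)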
def project_run_mult_command(
cmd_info: MultiprocessingCommandInfo,
*,
dry: bool = False,
cmd_ind: int,
mult_group_status_queue: Queue,
) -> None:
printer = Printer(no_print=True)
with TemporaryFile(mode="w+", encoding="utf8") as logfile:
print(printer.divider(cmd_info.cmd_name), file=logfile)
for os_cmd_ind, os_cmd in enumerate(cmd_info.cmd["script"]):
command = split_command(os_cmd)
if len(command) and command[0] in ("python", "python3"):
command = [sys.executable, "-u", *command[1:]]
elif len(command) and command[0] in ("pip", "pip3"):
command = [sys.executable, "-m", "pip", *command[1:]]
print(f"Running command: {join_command(command)}", logfile)
if not dry:
try:
sp = Popen(
command, stdout=logfile, env=environ.copy(), encoding="utf8"
)
except FileNotFoundError:
# Indicates the *command* wasn't found, it's an error before the command
# is run.
raise FileNotFoundError(
Errors.E970.format(
str_command=" ".join(command), tool=command[0]
)
) from None
mult_group_status_queue.put(
{
"cmd_ind": cmd_ind,
"os_cmd_ind": os_cmd_ind,
"status": "start",
"pid": sp.pid,
}
)
while True:
try:
rc = sp.wait(timeout=MULTIPROCESSING_GROUP_STATUS_INTERVAL)
break
except TimeoutExpired:
mult_group_status_queue.put(
{
"cmd_ind": cmd_ind,
"status": "alive",
}
)
if rc != 0:
status = "failed" if rc > 0 else "killed"
print(printer.divider(status.upper() + f" rc={rc}"), file=logfile)
logfile.seek(0)
mult_group_status_queue.put(
{"cmd_ind": cmd_ind, "status": status, "rc": rc, "output": logfile.read()}
)
return
logfile.seek(0)
mult_group_status_queue.put(
{"cmd_ind": cmd_ind, "status": "succeeded", "rc": 0, "output": logfile.read()}
)
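For reference, every message the worker puts on mult_group_status_queue and the parent reads back follows the same loose dict shape; a sketch of that protocol as a TypedDict (documentation aid only, not used by the code above; requires Python 3.8+ for typing.TypedDict):

from typing import TypedDict

class StatusMessage(TypedDict, total=False):
    cmd_ind: int     # index of the command within cmd_infos
    status: str      # "start", "alive", "succeeded", "failed" or "killed"
    os_cmd_ind: int  # which entry of cmd["script"] is running ("start" only)
    pid: int         # pid of the worker's subprocess ("start" only)
    rc: int          # return code ("succeeded"/"failed"/"killed" only)
    output: str      # captured log output ("succeeded"/"failed"/"killed" only)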
def run_commands(
commands: Iterable[str] = SimpleFrozenList(),
silent: bool = False,
dry: bool = False,
capture: bool = False,
) -> None:
"""Run a sequence of commands in a subprocess, in order.
commands (List[str]): The string commands.
silent (bool): Don't print the commands.
dry (bool): Perform a dry run and don't execute anything.
capture (bool): Whether to capture the output and errors of individual commands.
If False, the stdout and stderr will not be redirected, and if there's an error,
sys.exit will be called with the return code. You should use capture=False
when you want to turn over execution to the command, and capture=True
when you want to run the command more like a function.
"""
for c in commands:
command = split_command(c)
# Not sure if this is needed or a good idea. Motivation: users may often
# use commands in their config that reference "python" and we want to
# make sure that it's always executing the same Python that spaCy is
# executed with and the pip in the same env, not some other Python/pip.
# Also ensures cross-compatibility if user 1 writes "python3" (because
# that's how it's set up on their system), and user 2 without the
# shortcut tries to re-run the command.
if len(command) and command[0] in ("python", "python3"):
command[0] = sys.executable
elif len(command) and command[0] in ("pip", "pip3"):
command = [sys.executable, "-m", "pip", *command[1:]]
if not silent:
print(f"Running command: {join_command(command)}")
if not dry:
run_command(command, capture=capture)
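As a concrete illustration of the substitution described in the comment above, using the split_command and join_command helpers this module already relies on (the command string itself is invented):

import sys
from spacy.util import split_command, join_command

# An invented project.yml command line that references "python" directly.
parts = split_command("python -m spacy train configs/config.cfg")
if len(parts) and parts[0] in ("python", "python3"):
    # Re-point it at the interpreter spaCy itself is running under.
    parts[0] = sys.executable
print(f"Running command: {join_command(parts)}")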
def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
@@ -300,7 +497,6 @@ def check_rerun(
*,
check_spacy_version: bool = True,
check_spacy_commit: bool = False,
mult_group_mutex: Lock_t,
) -> bool:
"""Check if a command should be rerun because its settings or inputs/outputs
changed.
@@ -313,44 +509,38 @@ def check_rerun(
# Always rerun if no-skip is set
if command.get("no_skip", False):
return True
with mult_group_mutex:
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists(): # We don't have a lockfile, run command
return True
data = srsly.read_yaml(lock_path)
if command["name"] not in data: # We don't have info about this command
return True
entry = data[command["name"]]
# Always run commands with no outputs (otherwise they'd always be skipped)
if not entry.get("outs", []):
return True
# Always rerun if spaCy version or commit hash changed
spacy_v = entry.get("spacy_version")
commit = entry.get("spacy_git_version")
if check_spacy_version and not is_minor_version_match(
spacy_v, about.__version__
):
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
msg.info(
f"Re-running '{command['name']}': spaCy minor version changed {info}"
)
return True
if check_spacy_commit and commit != GIT_VERSION:
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
return True
# If the entry in the lockfile matches the lockfile entry that would be
# generated from the current command, we don't rerun because it means that
# all inputs/outputs, hashes and scripts are the same and nothing changed
lock_entry = get_lock_entry(project_dir, command)
exclude = ["spacy_version", "spacy_git_version"]
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists(): # We don't have a lockfile, run command
return True
data = srsly.read_yaml(lock_path)
if command["name"] not in data: # We don't have info about this command
return True
entry = data[command["name"]]
# Always run commands with no outputs (otherwise they'd always be skipped)
if not entry.get("outs", []):
return True
# Always rerun if spaCy version or commit hash changed
spacy_v = entry.get("spacy_version")
commit = entry.get("spacy_git_version")
if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
return True
if check_spacy_commit and commit != GIT_VERSION:
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
return True
# If the entry in the lockfile matches the lockfile entry that would be
# generated from the current command, we don't rerun because it means that
# all inputs/outputs, hashes and scripts are the same and nothing changed
lock_entry = get_lock_entry(project_dir, command)
exclude = ["spacy_version", "spacy_git_version"]
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
def update_lockfile(
project_dir: Path,
command: Dict[str, Any],
mult_group_mutex: Lock_t,
) -> None:
"""Update the lockfile after running a command. Will create a lockfile if
it doesn't yet exist and will add an entry for the current command, its
@@ -358,17 +548,15 @@ def update_lockfile(
project_dir (Path): The current project directory.
command (Dict[str, Any]): The command, as defined in the project.yml.
mult_group_mutex: the mutex preventing concurrent writes
"""
with mult_group_mutex:
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists():
srsly.write_yaml(lock_path, {})
data = {}
else:
data = srsly.read_yaml(lock_path)
data[command["name"]] = get_lock_entry(project_dir, command)
srsly.write_yaml(lock_path, data)
lock_path = project_dir / PROJECT_LOCK
if not lock_path.exists():
srsly.write_yaml(lock_path, {})
data = {}
else:
data = srsly.read_yaml(lock_path)
data[command["name"]] = get_lock_entry(project_dir, command)
srsly.write_yaml(lock_path, data)
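For orientation, the project.lock data written above (and read back by check_rerun) maps each command name to its lock entry; a hypothetical example of that structure, with invented hashes and versions, and the exact per-file fields depending on what get_lock_entry produces:

# Hypothetical lockfile contents after running a command named "preprocess".
# check_rerun only relies on "outs", "spacy_version" and "spacy_git_version"
# plus the overall hash of the entry; the rest is illustrative.
example_lock_data = {
    "preprocess": {
        "script": ["python scripts/convert.py"],
        "deps": [{"path": "assets/raw.jsonl", "md5": "0" * 32}],
        "outs": [{"path": "corpus/train.spacy", "md5": "f" * 32}],
        "spacy_version": "3.4.0",
        "spacy_git_version": "abc1234",
    }
}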
def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]: