Refactoring into separate module

This commit is contained in:
richardpaulhudson 2022-07-19 14:08:46 +02:00
parent e9ee680873
commit 4c2fc56a5b
6 changed files with 359 additions and 362 deletions

View File

@@ -18,8 +18,12 @@ import os
from ..compat import Literal
from ..schemas import ProjectConfigSchema, validate
from ..git_info import GIT_VERSION
from ..util import import_file, run_command, make_tempdir, registry, logger
from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
from ..util import is_minor_version_match
from .. import about
if TYPE_CHECKING:
@@ -598,3 +602,121 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
        local_msg.info("Using CPU")
        if has_cupy and gpu_is_available():
            local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")
def check_rerun(
    project_dir: Path,
    command: Dict[str, Any],
    *,
    check_spacy_version: bool = True,
    check_spacy_commit: bool = False,
) -> bool:
    """Check if a command should be rerun because its settings or inputs/outputs
    changed.

    project_dir (Path): The current project directory.
    command (Dict[str, Any]): The command, as defined in the project.yml.
    check_spacy_version (bool): Whether a change in the minor spaCy version
        should trigger a rerun.
    check_spacy_commit (bool): Whether a change in the spaCy commit hash
        should trigger a rerun.
    RETURNS (bool): Whether to re-run the command.
    """
    # Always rerun if no-skip is set
    if command.get("no_skip", False):
        return True
    lock_path = project_dir / PROJECT_LOCK
    if not lock_path.exists():  # We don't have a lockfile, run command
        return True
    data = srsly.read_yaml(lock_path)
    if command["name"] not in data:  # We don't have info about this command
        return True
    entry = data[command["name"]]
    # Always run commands with no outputs (otherwise they'd always be skipped)
    if not entry.get("outs", []):
        return True
    # Always rerun if spaCy version or commit hash changed
    spacy_v = entry.get("spacy_version")
    commit = entry.get("spacy_git_version")
    if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
        info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
        msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
        return True
    if check_spacy_commit and commit != GIT_VERSION:
        info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
        msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
        return True
    # If the entry in the lockfile matches the lockfile entry that would be
    # generated from the current command, we don't rerun because it means that
    # all inputs/outputs, hashes and scripts are the same and nothing changed
    lock_entry = _get_lock_entry(project_dir, command)
    exclude = ["spacy_version", "spacy_git_version"]
    return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
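
For illustration, the two helpers are meant to bracket the execution of a command. A minimal sketch, assuming the module layout this commit introduces; the command dict is a hypothetical example of what gets parsed from project.yml:

# Sketch (not part of the diff): rerun check and lockfile update together,
# assuming the new spacy.cli._util layout. The command dict is hypothetical.
from pathlib import Path
from spacy.cli._util import check_rerun, update_lockfile

project_dir = Path(".")
command = {
    "name": "preprocess",
    "script": ["python scripts/preprocess.py"],
    "deps": ["assets/raw.jsonl"],
    "outputs": ["corpus/train.spacy"],
}
if check_rerun(project_dir, command):
    # ... execute command["script"] here ...
    update_lockfile(project_dir, command)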
def update_lockfile(
    project_dir: Path,
    command: Dict[str, Any],
) -> None:
    """Update the lockfile after running a command. Will create a lockfile if
    it doesn't yet exist and will add an entry for the current command, its
    script and dependencies/outputs.

    project_dir (Path): The current project directory.
    command (Dict[str, Any]): The command, as defined in the project.yml.
    """
    lock_path = project_dir / PROJECT_LOCK
    if not lock_path.exists():
        srsly.write_yaml(lock_path, {})
        data = {}
    else:
        data = srsly.read_yaml(lock_path)
    data[command["name"]] = _get_lock_entry(project_dir, command)
    srsly.write_yaml(lock_path, data)


def _get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
    """Get a lockfile entry for a given command. An entry includes the command,
    the script (command steps) and a list of dependencies and outputs with
    their paths and file hashes, if available. The format is based on the
    dvc.lock files, to keep things consistent.

    project_dir (Path): The current project directory.
    command (Dict[str, Any]): The command, as defined in the project.yml.
    RETURNS (Dict[str, Any]): The lockfile entry.
    """
    deps = _get_fileinfo(project_dir, command.get("deps", []))
    outs = _get_fileinfo(project_dir, command.get("outputs", []))
    outs_nc = _get_fileinfo(project_dir, command.get("outputs_no_cache", []))
    return {
        "cmd": f"{COMMAND} run {command['name']}",
        "script": command["script"],
        "deps": deps,
        "outs": [*outs, *outs_nc],
        "spacy_version": about.__version__,
        "spacy_git_version": GIT_VERSION,
    }
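
For reference, a single project.lock entry produced by _get_lock_entry is shaped like the following; all values here are illustrative (COMMAND resolves to "python -m spacy"):

# Illustrative project.lock entry (hash, version and commit values made up):
{
    "cmd": "python -m spacy run preprocess",
    "script": ["python scripts/preprocess.py"],
    "deps": [{"path": "assets/raw.jsonl", "md5": "0123456789abcdef0123456789abcdef"}],
    "outs": [{"path": "corpus/train.spacy", "md5": None}],
    "spacy_version": "3.4.0",
    "spacy_git_version": "abcdef12",
}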
def _get_fileinfo(
    project_dir: Path, paths: List[str]
) -> List[Dict[str, Optional[str]]]:
    """Generate the file information for a list of paths (dependencies, outputs).
    Includes the file path and the file's checksum.

    project_dir (Path): The current project directory.
    paths (List[str]): The file paths.
    RETURNS (List[Dict[str, Optional[str]]]): The lockfile entries for the files.
    """
    data = []
    for path in paths:
        file_path = project_dir / path
        md5 = get_checksum(file_path) if file_path.exists() else None
        data.append({"path": path, "md5": md5})
    return data
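
The checksum is None for files that don't exist yet, which is what lets a lock entry record outputs that a command has not produced. An illustrative return value for one existing and one missing file:

# _get_fileinfo(project_dir, ["assets/raw.jsonl", "corpus/train.spacy"]) might
# return (hash made up; the second file does not exist yet):
[
    {"path": "assets/raw.jsonl", "md5": "0123456789abcdef0123456789abcdef"},
    {"path": "corpus/train.spacy", "md5": None},
]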
def check_deps(cmd: Dict, cmd_name: str, project_dir: Path, dry: bool):
    for dep in cmd.get("deps", []):
        if not (project_dir / dep).exists():
            err = f"Missing dependency specified by command '{cmd_name}': {dep}"
            err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
            err_kwargs = {"exits": 1} if not dry else {}
            msg.fail(err, err_help, **err_kwargs)

View File

@@ -1,4 +1,5 @@
-from typing import Any, List, Literal, Optional, Dict
+from typing import Any, List, Literal, Optional, Dict, Union, cast
import sys
from pathlib import Path
from time import time
import multiprocessing
@@ -7,14 +8,15 @@ from os import kill, environ, linesep, mkdir, sep
from shutil import rmtree
from signal import SIGTERM
from subprocess import STDOUT, Popen, TimeoutExpired
from wasabi import msg
from wasabi.util import color, supports_ansi
-from ...util import SimpleFrozenDict, load_project_config, working_dir
-from ...util import check_bool_env_var
+from .._util import check_rerun, check_deps, update_lockfile, load_project_config
+from ...util import SimpleFrozenDict, working_dir
+from ...util import check_bool_env_var, ENV_VARS, split_command, join_command
from ...errors import Errors
# How often (in seconds) the worker processes managing the commands in a
# parallel group send keepalive messages to the main process
PARALLEL_GROUP_STATUS_INTERVAL = 1
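
The workers communicate with the main process over a queue of plain dicts. Judging from the handling code in project_run_parallel_group below, the messages look roughly like this (a sketch; the field names are from the code, the values are hypothetical):

# Sketch of worker -> main process status messages (values hypothetical):
{"status": "started", "cmd_ind": 0, "os_cmd_ind": 0, "pid": 4242}
{"status": "alive", "cmd_ind": 0}
{"status": "completed", "cmd_ind": 0, "rc": 0, "output": "<contents of the log file>"}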
@@ -34,8 +36,10 @@ DISPLAY_STATUS_COLORS = {
    "hung": "red",
    "cancelled": "red",
}
-class ParallelCommandInfo:
-    def __init__(self, cmd_name: str, cmd: str, cmd_ind: int):
+class _ParallelCommandInfo:
+    def __init__(self, cmd_name: str, cmd: Dict, cmd_ind: int):
        self.cmd_name = cmd_name
        self.cmd = cmd
        self.cmd_ind = cmd_ind
@@ -59,21 +63,26 @@ class ParallelCommandInfo:
    @property
    def disp_status(self) -> str:
-        status = self.status
-        status_color = DISPLAY_STATUS_COLORS[status]
-        if status == "running":
-            status = f"{status} ({self.os_cmd_ind + 1}/{self.len_os_cmds})"
-        return color(status, status_color)
+        status_str = str(self.status)
+        status_color = DISPLAY_STATUS_COLORS[status_str]
+        if status_str == "running" and self.os_cmd_ind is not None:
+            status_str = f"{status_str} ({self.os_cmd_ind + 1}/{self.len_os_cmds})"
+        return color(status_str, status_color)
-def start_process(process:Process, proc_to_cmd_infos: Dict[Process, ParallelCommandInfo]) -> None:
+def _start_process(
+    process: Process, proc_to_cmd_infos: Dict[Process, _ParallelCommandInfo]
+) -> None:
    cmd_info = proc_to_cmd_infos[process]
    if cmd_info.status == "pending":
        cmd_info.status = "starting"
        cmd_info.last_status_time = int(time())
        process.start()
def project_run_parallel_group(
    project_dir: Path,
-    cmd_infos: List[ParallelCommandInfo],
+    cmd_names: List[str],
    *,
    overrides: Dict[str, Any] = SimpleFrozenDict(),
    force: bool = False,
@@ -90,11 +99,18 @@ def project_run_parallel_group(
    dry (bool): Perform a dry run and don't execute commands.
    """
    config = load_project_config(project_dir, overrides=overrides)
+    commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
    parallel_group_status_queue = MultiprocessingManager().Queue()
    max_parallel_processes = config.get("max_parallel_processes")
    check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
    multiprocessing.set_start_method("spawn", force=True)
    DISPLAY_STATUS = sys.stdout.isatty() and supports_ansi()
+    cmd_infos = [
+        _ParallelCommandInfo(cmd_name, commands[cmd_name], cmd_ind)
+        for cmd_ind, cmd_name in enumerate(cmd_names)
+    ]
    with working_dir(project_dir) as current_dir:
        for cmd_info in cmd_infos:
            check_deps(cmd_info.cmd, cmd_info.cmd_name, project_dir, dry)
@@ -107,14 +123,14 @@ def project_run_parallel_group(
                cmd_info.status = "not rerun"
        rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True)
        mkdir(PARALLEL_LOGGING_DIR_NAME)
-        processes : List[Process] = []
-        proc_to_cmd_infos: Dict[Process, ParallelCommandInfo] = {}
+        processes: List[Process] = []
+        proc_to_cmd_infos: Dict[Process, _ParallelCommandInfo] = {}
        num_processes = 0
        for cmd_info in cmd_infos:
            if cmd_info.status == "not rerun":
                continue
            process = Process(
-                target=project_run_parallel_cmd,
+                target=_project_run_parallel_cmd,
                args=(cmd_info,),
                kwargs={
                    "dry": dry,
@@ -132,7 +148,7 @@ def project_run_parallel_group(
            num_processes = max_parallel_processes
        process_iterator = iter(processes)
        for _ in range(num_processes):
-            start_process(next(process_iterator), proc_to_cmd_infos)
+            _start_process(next(process_iterator), proc_to_cmd_infos)
        divider_parallel_descriptor = parallel_descriptor = (
            "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]"
        )
@@ -157,10 +173,12 @@ def project_run_parallel_group(
                        other_cmd_info.status = "hung"
                for other_cmd_info in (c for c in cmd_infos if c.status == "pending"):
                    other_cmd_info.status = "cancelled"
-            cmd_info = cmd_infos[mess["cmd_ind"]]
+            cmd_info = cmd_infos[cast(int, mess["cmd_ind"])]
            if mess["status"] in ("started", "alive"):
                cmd_info.last_status_time = int(time())
-                for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
+                for other_cmd_info in (
+                    c for c in cmd_infos if c.status in ("starting", "running")
+                ):
                    if (
                        other_cmd_info.last_status_time is not None
                        and time() - other_cmd_info.last_status_time
@@ -169,26 +187,26 @@ def project_run_parallel_group(
                        other_cmd_info.status = "hung"
            if mess["status"] == "started":
                cmd_info.status = "running"
-                cmd_info.os_cmd_ind = mess["os_cmd_ind"]
-                cmd_info.pid = mess["pid"]
+                cmd_info.os_cmd_ind = cast(int, mess["os_cmd_ind"])
+                cmd_info.pid = cast(int, mess["pid"])
            if mess["status"] == "completed":
-                cmd_info.rc = mess["rc"]
+                cmd_info.rc = cast(int, mess["rc"])
                if cmd_info.rc == 0:
                    cmd_info.status = "succeeded"
                    if not dry:
                        update_lockfile(current_dir, cmd_info.cmd)
                    working_process = next(process_iterator, None)
                    if working_process is not None:
-                        start_process(working_process, proc_to_cmd_infos)
+                        _start_process(working_process, proc_to_cmd_infos)
                elif cmd_info.rc > 0:
                    cmd_info.status = "failed"
                else:
                    cmd_info.status = "killed"
-                cmd_info.output = mess["output"]
+                cmd_info.output = cast(str, mess["output"])
                if any(c for c in cmd_infos if c.status in ("failed", "killed", "hung")):
                    for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
                        try:
-                            kill(other_cmd_info.pid, SIGTERM)
+                            kill(cast(int, other_cmd_info.pid), SIGTERM)
                        except:
                            pass
                    for other_cmd_info in (c for c in cmd_infos if c.status == "pending"):
@@ -219,8 +237,8 @@ def project_run_parallel_group(
        sys.exit(-1)


-def project_run_parallel_cmd(
-    cmd_info: ParallelCommandInfo,
+def _project_run_parallel_cmd(
+    cmd_info: _ParallelCommandInfo,
    *,
    dry: bool,
    current_dir: str,
@@ -318,4 +336,4 @@ def project_run_parallel_cmd(
                "rc": rc,
                "output": logfile.read(),
            }
        )
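
Seen from the caller's side, run.py now only passes command names into this module. A minimal sketch of the entry point under the new layout (argument values hypothetical):

# Sketch: how a {"parallel": [...]} workflow step is dispatched after this
# refactor -- the names are resolved against the project.yml "commands" section.
from pathlib import Path
from spacy.cli.project.parallel import project_run_parallel_group

project_run_parallel_group(
    Path("."),                 # project directory containing project.yml
    ["commandA", "commandB"],  # command names from the parallel group
    force=False,
    dry=False,
)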

View File

@@ -4,8 +4,7 @@ from wasabi import msg
from .remote_storage import RemoteStorage
from .remote_storage import get_command_hash
from .._util import project_cli, Arg, logger
-from .._util import load_project_config
-from .run import update_lockfile
+from .._util import load_project_config, update_lockfile
@project_cli.command("pull")
@@ -38,7 +37,6 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
    # We use a while loop here because we don't know how the commands
    # will be ordered. A command might need dependencies from one that's later
    # in the list.
-    mult_group_mutex = multiprocessing.Lock()
    while commands:
        for i, cmd in enumerate(list(commands)):
            logger.debug(f"CMD: {cmd['name']}.")
@@ -54,7 +52,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
            out_locs = [project_dir / out for out in cmd.get("outputs", [])]
            if all(loc.exists() for loc in out_locs):
-                update_lockfile(project_dir, cmd, mult_group_mutex=mult_group_mutex)
+                update_lockfile(project_dir, cmd)
            # We remove the command from the list here, and break, so that
            # we iterate over the loop again.
            commands.pop(i)
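
The comment above describes a control-flow pattern that is easy to miss in the diff: keep sweeping the remaining commands, pull whatever is currently pullable, and restart the sweep so commands whose dependencies only just arrived get another chance. A standalone sketch of that pattern; was_pulled is a hypothetical stand-in for the actual pull logic:

# Standalone sketch of the while/for/pop-and-break pattern in project_pull.
commands = ["cmdA", "cmdB", "cmdC"]

def was_pulled(cmd: str) -> bool:
    return True  # placeholder: pull succeeded or outputs already exist

while commands:
    for i, cmd in enumerate(list(commands)):
        if was_pulled(cmd):
            commands.pop(i)
            break  # restart the sweep over the remaining commands
    else:
        break  # nothing could be pulled this sweep; avoid an infinite loop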

View File

@@ -3,16 +3,15 @@ import sys
from pathlib import Path
from wasabi import msg
from wasabi.util import locale_escape
import srsly
import typer
from ... import about
from ...git_info import GIT_VERSION
from .parallel import project_run_parallel_group
+from .._util import project_cli, Arg, Opt, COMMAND, parse_config_overrides, check_deps
+from .._util import PROJECT_FILE, load_project_config, check_rerun, update_lockfile
from ...util import working_dir, run_command, split_command, is_cwd, join_command
-from ...util import SimpleFrozenList, is_minor_version_match, ENV_VARS
+from ...util import SimpleFrozenList, ENV_VARS
from ...util import check_bool_env_var, SimpleFrozenDict
-from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
-from .._util import get_checksum, project_cli, Arg, Opt, COMMAND, parse_config_overrides
@project_cli.command(
@@ -42,15 +41,6 @@ def project_run_cli(
    project_run(project_dir, subcommand, overrides=overrides, force=force, dry=dry)


-def check_deps(cmd: Dict, cmd_name: str, project_dir: Path, dry: bool):
-    for dep in cmd.get("deps", []):
-        if not (project_dir / dep).exists():
-            err = f"Missing dependency specified by command '{cmd_name}': {dep}"
-            err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
-            err_kwargs = {"exits": 1} if not dry else {}
-            msg.fail(err, err_help, **err_kwargs)
def project_run(
    project_dir: Path,
    subcommand: str,
@@ -99,10 +89,7 @@ def project_run(
        assert isinstance(cmds[0], str)
        project_run_parallel_group(
            project_dir,
-            [
-                ParallelCommandInfo(cmd, commands[cmd], cmd_ind)
-                for cmd_ind, cmd in enumerate(cmds)
-            ],
+            cmds,
            overrides=overrides,
            force=force,
            dry=dry,
@@ -130,9 +117,6 @@ def project_run(
                update_lockfile(current_dir, cmd)


def run_commands(
    commands: Iterable[str] = SimpleFrozenList(),
    silent: bool = False,
@@ -250,42 +234,6 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
    msg.table(table_entries)


def run_commands(
    commands: Iterable[str] = SimpleFrozenList(),
    silent: bool = False,
    dry: bool = False,
    capture: bool = False,
) -> None:
    """Run a sequence of commands in a subprocess, in order.

    commands (List[str]): The string commands.
    silent (bool): Don't print the commands.
    dry (bool): Perform a dry run and don't execute anything.
    capture (bool): Whether to capture the output and errors of individual commands.
        If False, the stdout and stderr will not be redirected, and if there's an error,
        sys.exit will be called with the return code. You should use capture=False
        when you want to turn over execution to the command, and capture=True
        when you want to run the command more like a function.
    """
    for c in commands:
        command = split_command(c)
        # Not sure if this is needed or a good idea. Motivation: users may often
        # use commands in their config that reference "python" and we want to
        # make sure that it's always executing the same Python that spaCy is
        # executed with and the pip in the same env, not some other Python/pip.
        # Also ensures cross-compatibility if user 1 writes "python3" (because
        # that's how it's set up on their system), and user 2 without the
        # shortcut tries to re-run the command.
        if len(command) and command[0] in ("python", "python3"):
            command[0] = sys.executable
        elif len(command) and command[0] in ("pip", "pip3"):
            command = [sys.executable, "-m", "pip", *command[1:]]
        if not silent:
            print(f"Running command: {join_command(command)}")
        if not dry:
            run_command(command, capture=capture)
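
A hedged usage sketch of run_commands as documented above; the command strings are hypothetical examples, and the python/pip substitution means both lines run against the interpreter spaCy itself runs under:

# Sketch: both commands are rewritten to use sys.executable before running.
run_commands(
    ["python -m spacy info", "pip install -U srsly"],
    silent=False,
    dry=True,  # with dry=True the commands are printed but not executed
)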
def validate_subcommand(
    commands: Sequence[str], workflows: Sequence[str], subcommand: str
) -> None:
@@ -307,110 +255,3 @@ def validate_subcommand(
        ". ".join(help_msg),
        exits=1,
    )
def check_rerun(
    project_dir: Path,
    command: Dict[str, Any],
    *,
    check_spacy_version: bool = True,
    check_spacy_commit: bool = False,
) -> bool:
    """Check if a command should be rerun because its settings or inputs/outputs
    changed.

    project_dir (Path): The current project directory.
    command (Dict[str, Any]): The command, as defined in the project.yml.
    strict_version (bool):
    RETURNS (bool): Whether to re-run the command.
    """
    # Always rerun if no-skip is set
    if command.get("no_skip", False):
        return True
    lock_path = project_dir / PROJECT_LOCK
    if not lock_path.exists():  # We don't have a lockfile, run command
        return True
    data = srsly.read_yaml(lock_path)
    if command["name"] not in data:  # We don't have info about this command
        return True
    entry = data[command["name"]]
    # Always run commands with no outputs (otherwise they'd always be skipped)
    if not entry.get("outs", []):
        return True
    # Always rerun if spaCy version or commit hash changed
    spacy_v = entry.get("spacy_version")
    commit = entry.get("spacy_git_version")
    if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
        info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
        msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
        return True
    if check_spacy_commit and commit != GIT_VERSION:
        info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
        msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
        return True
    # If the entry in the lockfile matches the lockfile entry that would be
    # generated from the current command, we don't rerun because it means that
    # all inputs/outputs, hashes and scripts are the same and nothing changed
    lock_entry = get_lock_entry(project_dir, command)
    exclude = ["spacy_version", "spacy_git_version"]
    return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)


def update_lockfile(
    project_dir: Path,
    command: Dict[str, Any],
) -> None:
    """Update the lockfile after running a command. Will create a lockfile if
    it doesn't yet exist and will add an entry for the current command, its
    script and dependencies/outputs.

    project_dir (Path): The current project directory.
    command (Dict[str, Any]): The command, as defined in the project.yml.
    """
    lock_path = project_dir / PROJECT_LOCK
    if not lock_path.exists():
        srsly.write_yaml(lock_path, {})
        data = {}
    else:
        data = srsly.read_yaml(lock_path)
    data[command["name"]] = get_lock_entry(project_dir, command)
    srsly.write_yaml(lock_path, data)


def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
    """Get a lockfile entry for a given command. An entry includes the command,
    the script (command steps) and a list of dependencies and outputs with
    their paths and file hashes, if available. The format is based on the
    dvc.lock files, to keep things consistent.

    project_dir (Path): The current project directory.
    command (Dict[str, Any]): The command, as defined in the project.yml.
    RETURNS (Dict[str, Any]): The lockfile entry.
    """
    deps = get_fileinfo(project_dir, command.get("deps", []))
    outs = get_fileinfo(project_dir, command.get("outputs", []))
    outs_nc = get_fileinfo(project_dir, command.get("outputs_no_cache", []))
    return {
        "cmd": f"{COMMAND} run {command['name']}",
        "script": command["script"],
        "deps": deps,
        "outs": [*outs, *outs_nc],
        "spacy_version": about.__version__,
        "spacy_git_version": GIT_VERSION,
    }


def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, Optional[str]]]:
    """Generate the file information for a list of paths (dependencies, outputs).
    Includes the file path and the file's checksum.

    project_dir (Path): The current project directory.
    paths (List[str]): The file paths.
    RETURNS (List[Dict[str, str]]): The lockfile entry for a file.
    """
    data = []
    for path in paths:
        file_path = project_dir / path
        md5 = get_checksum(file_path) if file_path.exists() else None
        data.append({"path": path, "md5": md5})
    return data

View File

@@ -19,7 +19,6 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.package import get_third_party_dependencies
from spacy.cli.package import _is_permitted_package_name
from spacy.cli.validate import get_model_pkgs
-from spacy.cli.project.run import project_run
from spacy.lang.en import English
from spacy.lang.nl import Dutch
from spacy.language import Language
@@ -422,170 +421,6 @@ def test_project_config_interpolation(int_value):
        substitute_project_variables(project)
def test_project_config_multiprocessing_good_case():
    project = {
        "commands": [
            {"name": "command1", "script": ["echo", "command1"]},
            {"name": "command2", "script": ["echo", "command2"]},
            {"name": "command3", "script": ["echo", "command3"]},
        ],
        "workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]},
    }
    with make_tempdir() as d:
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)


@pytest.mark.parametrize(
    "workflows",
    [
        {"all": ["command1", {"parallel": ["command2"]}, "command3"]},
        {"all": ["command1", {"parallel": ["command2", "command4"]}]},
        {"all": ["command1", {"parallel": ["command2", "command2"]}]},
        {"all": ["command1", {"serial": ["command2", "command3"]}]},
    ],
)
def test_project_config_multiprocessing_bad_case_workflows(workflows):
    project = {
        "commands": [
            {"name": "command1", "script": ["echo", "command1"]},
            {"name": "command2", "script": ["echo", "command2"]},
            {"name": "command3", "script": ["echo", "command3"]},
        ],
        "workflows": workflows,
    }
    with make_tempdir() as d:
        srsly.write_yaml(d / "project.yml", project)
        with pytest.raises(SystemExit):
            load_project_config(d)


@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
    with make_tempdir() as d:
        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
            ],
            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]},
        }
        with make_tempdir() as d:
            srsly.write_yaml(d / "project.yml", project)
            with pytest.raises(SystemExit):
                load_project_config(d)


def test_project_run_multiprocessing_good_case():
    with make_tempdir() as d:
        pscript = """
import sys, os
from time import sleep

_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
    out_file.write("")
while True:
    if os.path.exists(os.sep.join((d_path, in_filename))):
        break
    sleep(0.1)
if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
    with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
        first_flag_file.write("")
else:  # should never happen because of skipping
    with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
        second_flag_file.write("")
"""
        pscript_loc = os.sep.join((str(d), "pscript.py"))
        with open(pscript_loc, "w") as pscript_file:
            pscript_file.write(pscript)
        os.chmod(pscript_loc, 0o777)
        project = {
            "commands": [
                {
                    "name": "commandA",
                    "script": [
                        " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
                    ],
                    "outputs": [os.sep.join((str(d), "f"))],
                },
                {
                    "name": "commandB",
                    "script": [
                        " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
                    ],
                    "outputs": [os.sep.join((str(d), "e"))],
                },
            ],
            "workflows": {
                "all": [
                    {"parallel": ["commandA", "commandB"]},
                    {"parallel": ["commandB", "commandA"]},
                ]
            },
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        assert os.path.exists(os.sep.join((str(d), "c")))
        assert os.path.exists(os.sep.join((str(d), "e")))
        assert not os.path.exists(os.sep.join((str(d), "d")))
        assert not os.path.exists(os.sep.join((str(d), "f")))


@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6])
def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
    with make_tempdir() as d:
        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
                {
                    "name": "commandD",
                    "script": [" ".join(("touch", os.sep.join((str(d), "D"))))],
                },
                {
                    "name": "commandE",
                    "script": [" ".join(("touch", os.sep.join((str(d), "E"))))],
                },
            ],
            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC", "commandD", "commandE"]}]},
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        assert os.path.exists(os.sep.join((str(d), "A")))
        assert os.path.exists(os.sep.join((str(d), "B")))
        assert os.path.exists(os.sep.join((str(d), "C")))
        assert os.path.exists(os.sep.join((str(d), "D")))
        assert os.path.exists(os.sep.join((str(d), "E")))
@pytest.mark.parametrize(
"greeting",
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],

View File

@@ -0,0 +1,183 @@
import os
import pytest
import srsly
from spacy.cli._util import load_project_config
from spacy.cli.project.run import project_run
from .util import make_tempdir
def test_project_config_multiprocessing_good_case():
    project = {
        "commands": [
            {"name": "command1", "script": ["echo", "command1"]},
            {"name": "command2", "script": ["echo", "command2"]},
            {"name": "command3", "script": ["echo", "command3"]},
        ],
        "workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]},
    }
    with make_tempdir() as d:
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)


@pytest.mark.parametrize(
    "workflows",
    [
        {"all": ["command1", {"parallel": ["command2"]}, "command3"]},
        {"all": ["command1", {"parallel": ["command2", "command4"]}]},
        {"all": ["command1", {"parallel": ["command2", "command2"]}]},
        {"all": ["command1", {"serial": ["command2", "command3"]}]},
    ],
)
def test_project_config_multiprocessing_bad_case_workflows(workflows):
    project = {
        "commands": [
            {"name": "command1", "script": ["echo", "command1"]},
            {"name": "command2", "script": ["echo", "command2"]},
            {"name": "command3", "script": ["echo", "command3"]},
        ],
        "workflows": workflows,
    }
    with make_tempdir() as d:
        srsly.write_yaml(d / "project.yml", project)
        with pytest.raises(SystemExit):
            load_project_config(d)


@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
    with make_tempdir() as d:
        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
            ],
            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]},
        }
        with make_tempdir() as d:
            srsly.write_yaml(d / "project.yml", project)
            with pytest.raises(SystemExit):
                load_project_config(d)


def test_project_run_multiprocessing_good_case():
    with make_tempdir() as d:
        pscript = """
import sys, os
from time import sleep

_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
    out_file.write("")
while True:
    if os.path.exists(os.sep.join((d_path, in_filename))):
        break
    sleep(0.1)
if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
    with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
        first_flag_file.write("")
else:  # should never happen because of skipping
    with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
        second_flag_file.write("")
"""
        pscript_loc = os.sep.join((str(d), "pscript.py"))
        with open(pscript_loc, "w") as pscript_file:
            pscript_file.write(pscript)
        os.chmod(pscript_loc, 0o777)
        project = {
            "commands": [
                {
                    "name": "commandA",
                    "script": [
                        " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
                    ],
                    "outputs": [os.sep.join((str(d), "f"))],
                },
                {
                    "name": "commandB",
                    "script": [
                        " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
                    ],
                    "outputs": [os.sep.join((str(d), "e"))],
                },
            ],
            "workflows": {
                "all": [
                    {"parallel": ["commandA", "commandB"]},
                    {"parallel": ["commandB", "commandA"]},
                ]
            },
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        assert os.path.exists(os.sep.join((str(d), "c")))
        assert os.path.exists(os.sep.join((str(d), "e")))
        assert not os.path.exists(os.sep.join((str(d), "d")))
        assert not os.path.exists(os.sep.join((str(d), "f")))


@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6])
def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
    with make_tempdir() as d:
        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
                {
                    "name": "commandD",
                    "script": [" ".join(("touch", os.sep.join((str(d), "D"))))],
                },
                {
                    "name": "commandE",
                    "script": [" ".join(("touch", os.sep.join((str(d), "E"))))],
                },
            ],
            "workflows": {
                "all": [
                    {
                        "parallel": [
                            "commandA",
                            "commandB",
                            "commandC",
                            "commandD",
                            "commandE",
                        ]
                    }
                ]
            },
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        assert os.path.exists(os.sep.join((str(d), "A")))
        assert os.path.exists(os.sep.join((str(d), "B")))
        assert os.path.exists(os.sep.join((str(d), "C")))
        assert os.path.exists(os.sep.join((str(d), "D")))
        assert os.path.exists(os.sep.join((str(d), "E")))