Refactoring into separate module

commit 4c2fc56a5b (parent e9ee680873)
@@ -18,8 +18,12 @@ import os
 from ..compat import Literal
 from ..schemas import ProjectConfigSchema, validate
+from ..git_info import GIT_VERSION
 from ..util import import_file, run_command, make_tempdir, registry, logger
 from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
+from ..util import is_minor_version_match
+
+from .. import about

 if TYPE_CHECKING:
@@ -598,3 +602,121 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
        local_msg.info("Using CPU")
        if has_cupy and gpu_is_available():
            local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")
+
+
+def check_rerun(
+    project_dir: Path,
+    command: Dict[str, Any],
+    *,
+    check_spacy_version: bool = True,
+    check_spacy_commit: bool = False,
+) -> bool:
+    """Check if a command should be rerun because its settings or inputs/outputs
+    changed.
+
+    project_dir (Path): The current project directory.
+    command (Dict[str, Any]): The command, as defined in the project.yml.
+    check_spacy_version (bool): Whether to rerun on a spaCy minor version change.
+    check_spacy_commit (bool): Whether to rerun on a spaCy commit hash change.
+    RETURNS (bool): Whether to re-run the command.
+    """
+    # Always rerun if no-skip is set
+    if command.get("no_skip", False):
+        return True
+    lock_path = project_dir / PROJECT_LOCK
+    if not lock_path.exists():  # We don't have a lockfile, run command
+        return True
+    data = srsly.read_yaml(lock_path)
+    if command["name"] not in data:  # We don't have info about this command
+        return True
+    entry = data[command["name"]]
+    # Always run commands with no outputs (otherwise they'd always be skipped)
+    if not entry.get("outs", []):
+        return True
+    # Always rerun if spaCy version or commit hash changed
+    spacy_v = entry.get("spacy_version")
+    commit = entry.get("spacy_git_version")
+    if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
+        info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
+        msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
+        return True
+    if check_spacy_commit and commit != GIT_VERSION:
+        info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
+        msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
+        return True
+    # If the entry in the lockfile matches the lockfile entry that would be
+    # generated from the current command, we don't rerun because it means that
+    # all inputs/outputs, hashes and scripts are the same and nothing changed
+    lock_entry = _get_lock_entry(project_dir, command)
+    exclude = ["spacy_version", "spacy_git_version"]
+    return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
+
+
+def update_lockfile(
+    project_dir: Path,
+    command: Dict[str, Any],
+) -> None:
+    """Update the lockfile after running a command. Will create a lockfile if
+    it doesn't yet exist and will add an entry for the current command, its
+    script and dependencies/outputs.
+
+    project_dir (Path): The current project directory.
+    command (Dict[str, Any]): The command, as defined in the project.yml.
+    """
+    lock_path = project_dir / PROJECT_LOCK
+    if not lock_path.exists():
+        srsly.write_yaml(lock_path, {})
+        data = {}
+    else:
+        data = srsly.read_yaml(lock_path)
+    data[command["name"]] = _get_lock_entry(project_dir, command)
+    srsly.write_yaml(lock_path, data)
+
+
+def _get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
+    """Get a lockfile entry for a given command. An entry includes the command,
+    the script (command steps) and a list of dependencies and outputs with
+    their paths and file hashes, if available. The format is based on the
+    dvc.lock files, to keep things consistent.
+
+    project_dir (Path): The current project directory.
+    command (Dict[str, Any]): The command, as defined in the project.yml.
+    RETURNS (Dict[str, Any]): The lockfile entry.
+    """
+    deps = _get_fileinfo(project_dir, command.get("deps", []))
+    outs = _get_fileinfo(project_dir, command.get("outputs", []))
+    outs_nc = _get_fileinfo(project_dir, command.get("outputs_no_cache", []))
+    return {
+        "cmd": f"{COMMAND} run {command['name']}",
+        "script": command["script"],
+        "deps": deps,
+        "outs": [*outs, *outs_nc],
+        "spacy_version": about.__version__,
+        "spacy_git_version": GIT_VERSION,
+    }
+
+
+def _get_fileinfo(
+    project_dir: Path, paths: List[str]
+) -> List[Dict[str, Optional[str]]]:
+    """Generate the file information for a list of paths (dependencies, outputs).
+    Includes the file path and the file's checksum.
+
+    project_dir (Path): The current project directory.
+    paths (List[str]): The file paths.
+    RETURNS (List[Dict[str, Optional[str]]]): The lockfile entries for the files.
+    """
+    data = []
+    for path in paths:
+        file_path = project_dir / path
+        md5 = get_checksum(file_path) if file_path.exists() else None
+        data.append({"path": path, "md5": md5})
+    return data
+
+
+def check_deps(cmd: Dict, cmd_name: str, project_dir: Path, dry: bool):
+    for dep in cmd.get("deps", []):
+        if not (project_dir / dep).exists():
+            err = f"Missing dependency specified by command '{cmd_name}': {dep}"
+            err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
+            err_kwargs = {"exits": 1} if not dry else {}
+            msg.fail(err, err_help, **err_kwargs)
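
Note: check_rerun and update_lockfile, moved here from spacy/cli/project/run.py (with the helpers renamed to _get_lock_entry and _get_fileinfo), coordinate through entries of this shape in the project.lock file. A sketch of one entry and of the change-detection comparison, assuming get_hash serializes the dict minus the excluded keys and hashes the result; all paths, hashes and version values below are invented for illustration:

    # Hypothetical lockfile entry, as assembled by _get_lock_entry() above:
    entry = {
        "cmd": "python -m spacy run preprocess",  # f"{COMMAND} run {command['name']}"
        "script": ["python scripts/preprocess.py assets/raw.txt corpus/train.spacy"],
        "deps": [{"path": "assets/raw.txt", "md5": "1b2ee296432eef1853dcf3ab8c44f5b1"}],
        "outs": [{"path": "corpus/train.spacy", "md5": None}],  # md5 is None if the file is missing
        "spacy_version": "3.2.0",
        "spacy_git_version": "0fc3dee",
    }
    # check_rerun() skips a command when the stored entry and a freshly built
    # one hash identically once the version keys are excluded (those two keys
    # are compared separately, each with its own re-run message):
    exclude = ["spacy_version", "spacy_git_version"]
    # skip if: get_hash(entry, exclude=exclude) == get_hash(fresh_entry, exclude=exclude)
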
@@ -1,4 +1,5 @@
-from typing import Any, List, Literal, Optional, Dict
+from typing import Any, List, Literal, Optional, Dict, Union, cast
 import sys
 from pathlib import Path
 from time import time
+import multiprocessing
@@ -7,14 +8,15 @@ from os import kill, environ, linesep, mkdir, sep
 from shutil import rmtree
 from signal import SIGTERM
 from subprocess import STDOUT, Popen, TimeoutExpired
 from wasabi import msg
 from wasabi.util import color, supports_ansi

-from ...util import SimpleFrozenDict, load_project_config, working_dir
-from ...util import check_bool_env_var
+from .._util import check_rerun, check_deps, update_lockfile, load_project_config
+from ...util import SimpleFrozenDict, working_dir
+from ...util import check_bool_env_var, ENV_VARS, split_command, join_command
 from ...errors import Errors


 # How often the worker processes managing the commands in a parallel group
 # send keepalive messages to the main process
 PARALLEL_GROUP_STATUS_INTERVAL = 1
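
Note: judging from the consumer loop further down, the status messages the worker processes put on the shared queue are plain dicts, roughly of these shapes (the field values here are invented):

    started = {"status": "started", "cmd_ind": 0, "os_cmd_ind": 0, "pid": 4242}
    alive = {"status": "alive", "cmd_ind": 0}  # sent every PARALLEL_GROUP_STATUS_INTERVAL seconds
    completed = {"status": "completed", "cmd_ind": 0, "rc": 0, "output": "..."}
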
@@ -34,8 +36,10 @@ DISPLAY_STATUS_COLORS = {
     "hung": "red",
     "cancelled": "red",
 }
-class ParallelCommandInfo:
-    def __init__(self, cmd_name: str, cmd: str, cmd_ind: int):
+
+
+class _ParallelCommandInfo:
+    def __init__(self, cmd_name: str, cmd: Dict, cmd_ind: int):
         self.cmd_name = cmd_name
         self.cmd = cmd
         self.cmd_ind = cmd_ind
@@ -59,21 +63,26 @@ class ParallelCommandInfo:

    @property
    def disp_status(self) -> str:
-        status = self.status
-        status_color = DISPLAY_STATUS_COLORS[status]
-        if status == "running":
-            status = f"{status} ({self.os_cmd_ind + 1}/{self.len_os_cmds})"
-        return color(status, status_color)
+        status_str = str(self.status)
+        status_color = DISPLAY_STATUS_COLORS[status_str]
+        if status_str == "running" and self.os_cmd_ind is not None:
+            status_str = f"{status_str} ({self.os_cmd_ind + 1}/{self.len_os_cmds})"
+        return color(status_str, status_color)

-def start_process(process:Process, proc_to_cmd_infos: Dict[Process, ParallelCommandInfo]) -> None:
+
+def _start_process(
+    process: Process, proc_to_cmd_infos: Dict[Process, _ParallelCommandInfo]
+) -> None:
    cmd_info = proc_to_cmd_infos[process]
    if cmd_info.status == "pending":
        cmd_info.status = "starting"
        cmd_info.last_status_time = int(time())
        process.start()


def project_run_parallel_group(
    project_dir: Path,
-    cmd_infos: List[ParallelCommandInfo],
+    cmd_names: List[str],
    *,
    overrides: Dict[str, Any] = SimpleFrozenDict(),
    force: bool = False,
@@ -90,11 +99,18 @@ def project_run_parallel_group(
    dry (bool): Perform a dry run and don't execute commands.
    """
+    config = load_project_config(project_dir, overrides=overrides)
+    commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
+
    parallel_group_status_queue = MultiprocessingManager().Queue()
+    max_parallel_processes = config.get("max_parallel_processes")
    check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
    multiprocessing.set_start_method("spawn", force=True)
    DISPLAY_STATUS = sys.stdout.isatty() and supports_ansi()
+    cmd_infos = [
+        _ParallelCommandInfo(cmd_name, commands[cmd_name], cmd_ind)
+        for cmd_ind, cmd_name in enumerate(cmd_names)
+    ]

    with working_dir(project_dir) as current_dir:
        for cmd_info in cmd_infos:
            check_deps(cmd_info.cmd, cmd_info.cmd_name, project_dir, dry)
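
Note: forcing the "spawn" start method makes each worker a fresh interpreter on every platform, rather than a fork() copy of the parent on Linux; spawn is the only mode available on Windows and the macOS default since Python 3.8, so this keeps all three platforms behaving alike. A minimal standalone illustration:

    import multiprocessing

    def work():
        print("running in a freshly spawned interpreter")

    if __name__ == "__main__":
        # Same call as in project_run_parallel_group(); force=True allows
        # resetting the method even if one was already chosen.
        multiprocessing.set_start_method("spawn", force=True)
        p = multiprocessing.Process(target=work)
        p.start()
        p.join()
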
@@ -107,14 +123,14 @@ def project_run_parallel_group(
                cmd_info.status = "not rerun"
        rmtree(PARALLEL_LOGGING_DIR_NAME, ignore_errors=True)
        mkdir(PARALLEL_LOGGING_DIR_NAME)
-        processes : List[Process] = []
-        proc_to_cmd_infos: Dict[Process, ParallelCommandInfo] = {}
+        processes: List[Process] = []
+        proc_to_cmd_infos: Dict[Process, _ParallelCommandInfo] = {}
        num_processes = 0
        for cmd_info in cmd_infos:
            if cmd_info.status == "not rerun":
                continue
            process = Process(
-                target=project_run_parallel_cmd,
+                target=_project_run_parallel_cmd,
                args=(cmd_info,),
                kwargs={
                    "dry": dry,
@@ -132,7 +148,7 @@ def project_run_parallel_group(
            num_processes = max_parallel_processes
        process_iterator = iter(processes)
        for _ in range(num_processes):
-            start_process(next(process_iterator), proc_to_cmd_infos)
+            _start_process(next(process_iterator), proc_to_cmd_infos)
        divider_parallel_descriptor = parallel_descriptor = (
            "parallel[" + ", ".join(cmd_info.cmd_name for cmd_info in cmd_infos) + "]"
        )
@@ -157,10 +173,12 @@ def project_run_parallel_group(
                    other_cmd_info.status = "hung"
            for other_cmd_info in (c for c in cmd_infos if c.status == "pending"):
                other_cmd_info.status = "cancelled"
-            cmd_info = cmd_infos[mess["cmd_ind"]]
+            cmd_info = cmd_infos[cast(int, mess["cmd_ind"])]
            if mess["status"] in ("started", "alive"):
                cmd_info.last_status_time = int(time())
-                for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
+                for other_cmd_info in (
+                    c for c in cmd_infos if c.status in ("starting", "running")
+                ):
                    if (
                        other_cmd_info.last_status_time is not None
                        and time() - other_cmd_info.last_status_time
@@ -169,26 +187,26 @@ def project_run_parallel_group(
                        other_cmd_info.status = "hung"
            if mess["status"] == "started":
                cmd_info.status = "running"
-                cmd_info.os_cmd_ind = mess["os_cmd_ind"]
-                cmd_info.pid = mess["pid"]
+                cmd_info.os_cmd_ind = cast(int, mess["os_cmd_ind"])
+                cmd_info.pid = cast(int, mess["pid"])
            if mess["status"] == "completed":
-                cmd_info.rc = mess["rc"]
+                cmd_info.rc = cast(int, mess["rc"])
                if cmd_info.rc == 0:
                    cmd_info.status = "succeeded"
                    if not dry:
                        update_lockfile(current_dir, cmd_info.cmd)
                    working_process = next(process_iterator, None)
                    if working_process is not None:
-                        start_process(working_process, proc_to_cmd_infos)
+                        _start_process(working_process, proc_to_cmd_infos)
                elif cmd_info.rc > 0:
                    cmd_info.status = "failed"
                else:
                    cmd_info.status = "killed"
-                cmd_info.output = mess["output"]
+                cmd_info.output = cast(str, mess["output"])
            if any(c for c in cmd_infos if c.status in ("failed", "killed", "hung")):
                for other_cmd_info in (c for c in cmd_infos if c.status == "running"):
                    try:
-                        kill(other_cmd_info.pid, SIGTERM)
+                        kill(cast(int, other_cmd_info.pid), SIGTERM)
                    except:
                        pass
                for other_cmd_info in (c for c in cmd_infos if c.status == "pending"):
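
Note: the hang check above compares the time since a command's last keepalive against a timeout; the constant on the right-hand side of the comparison is cut off by the hunk boundary, so the name below is a placeholder. Distilled:

    from time import time

    PARALLEL_HANG_TIMEOUT = 60  # placeholder; the real constant is not shown in this hunk

    def is_hung(last_status_time: int) -> bool:
        # A worker that has not reported "alive" within the timeout is presumed
        # hung; the group then fails, running siblings get SIGTERM, and pending
        # commands are cancelled.
        return time() - last_status_time > PARALLEL_HANG_TIMEOUT
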
@@ -219,8 +237,8 @@ def project_run_parallel_group(
        sys.exit(-1)


-def project_run_parallel_cmd(
-    cmd_info: ParallelCommandInfo,
+def _project_run_parallel_cmd(
+    cmd_info: _ParallelCommandInfo,
    *,
    dry: bool,
    current_dir: str,
@@ -318,4 +336,4 @@ def project_run_parallel_cmd(
                "rc": rc,
                "output": logfile.read(),
            }
        )
    )
@@ -4,8 +4,7 @@ from wasabi import msg
 from .remote_storage import RemoteStorage
 from .remote_storage import get_command_hash
 from .._util import project_cli, Arg, logger
-from .._util import load_project_config
-from .run import update_lockfile
+from .._util import load_project_config, update_lockfile


 @project_cli.command("pull")
@@ -38,7 +37,6 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
    # We use a while loop here because we don't know how the commands
    # will be ordered. A command might need dependencies from one that's later
    # in the list.
-    mult_group_mutex = multiprocessing.Lock()
    while commands:
        for i, cmd in enumerate(list(commands)):
            logger.debug(f"CMD: {cmd['name']}.")
@@ -54,7 +52,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):

        out_locs = [project_dir / out for out in cmd.get("outputs", [])]
        if all(loc.exists() for loc in out_locs):
-            update_lockfile(project_dir, cmd, mult_group_mutex=mult_group_mutex)
+            update_lockfile(project_dir, cmd)
            # We remove the command from the list here, and break, so that
            # we iterate over the loop again.
            commands.pop(i)
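
Note: the surrounding while loop retries commands in arbitrary dependency order: a command is only finalized once all of its declared outputs exist locally, and each success restarts the sweep because it may unblock a command earlier in the list. The shape of the loop, simplified (the actual pull from remote storage is elided, and the no-progress exit is an assumption about how such a loop must terminate):

    from pathlib import Path
    from typing import Any, Dict, List

    def sweep_until_stable(project_dir: Path, commands: List[Dict[str, Any]]) -> None:
        remaining = list(commands)
        while remaining:
            for i, cmd in enumerate(remaining):
                outs = [project_dir / out for out in cmd.get("outputs", [])]
                if all(out.exists() for out in outs):
                    remaining.pop(i)
                    break  # restart the sweep from the top
            else:
                break  # a full pass made no progress; stop rather than spin
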
@@ -3,16 +3,15 @@ import sys
 from pathlib import Path
 from wasabi import msg
 from wasabi.util import locale_escape
 import srsly
 import typer

 from ... import about
 from ...git_info import GIT_VERSION

 from .parallel import project_run_parallel_group
+from .._util import project_cli, Arg, Opt, COMMAND, parse_config_overrides, check_deps
+from .._util import PROJECT_FILE, load_project_config, check_rerun, update_lockfile
 from ...util import working_dir, run_command, split_command, is_cwd, join_command
-from ...util import SimpleFrozenList, is_minor_version_match, ENV_VARS
+from ...util import SimpleFrozenList, ENV_VARS
 from ...util import check_bool_env_var, SimpleFrozenDict
-from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
-from .._util import get_checksum, project_cli, Arg, Opt, COMMAND, parse_config_overrides


 @project_cli.command(
@@ -42,15 +41,6 @@ def project_run_cli(
        project_run(project_dir, subcommand, overrides=overrides, force=force, dry=dry)


-def check_deps(cmd: Dict, cmd_name: str, project_dir: Path, dry: bool):
-    for dep in cmd.get("deps", []):
-        if not (project_dir / dep).exists():
-            err = f"Missing dependency specified by command '{cmd_name}': {dep}"
-            err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
-            err_kwargs = {"exits": 1} if not dry else {}
-            msg.fail(err, err_help, **err_kwargs)


def project_run(
    project_dir: Path,
    subcommand: str,
@@ -99,10 +89,7 @@ def project_run(
        assert isinstance(cmds[0], str)
        project_run_parallel_group(
            project_dir,
-            [
-                ParallelCommandInfo(cmd, commands[cmd], cmd_ind)
-                for cmd_ind, cmd in enumerate(cmds)
-            ],
+            cmds,
            overrides=overrides,
            force=force,
            dry=dry,
@@ -130,9 +117,6 @@ def project_run(
                update_lockfile(current_dir, cmd)
-
-
-


def run_commands(
    commands: Iterable[str] = SimpleFrozenList(),
    silent: bool = False,
@@ -250,42 +234,6 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
    msg.table(table_entries)


-def run_commands(
-    commands: Iterable[str] = SimpleFrozenList(),
-    silent: bool = False,
-    dry: bool = False,
-    capture: bool = False,
-) -> None:
-    """Run a sequence of commands in a subprocess, in order.
-
-    commands (List[str]): The string commands.
-    silent (bool): Don't print the commands.
-    dry (bool): Perform a dry run and don't execute anything.
-    capture (bool): Whether to capture the output and errors of individual commands.
-        If False, the stdout and stderr will not be redirected, and if there's an error,
-        sys.exit will be called with the return code. You should use capture=False
-        when you want to turn over execution to the command, and capture=True
-        when you want to run the command more like a function.
-    """
-    for c in commands:
-        command = split_command(c)
-        # Not sure if this is needed or a good idea. Motivation: users may often
-        # use commands in their config that reference "python" and we want to
-        # make sure that it's always executing the same Python that spaCy is
-        # executed with and the pip in the same env, not some other Python/pip.
-        # Also ensures cross-compatibility if user 1 writes "python3" (because
-        # that's how it's set up on their system), and user 2 without the
-        # shortcut tries to re-run the command.
-        if len(command) and command[0] in ("python", "python3"):
-            command[0] = sys.executable
-        elif len(command) and command[0] in ("pip", "pip3"):
-            command = [sys.executable, "-m", "pip", *command[1:]]
-        if not silent:
-            print(f"Running command: {join_command(command)}")
-        if not dry:
-            run_command(command, capture=capture)
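
Note: the block above appears to be a duplicated run_commands definition being deleted; the surviving copy (visible as context around old line 130) is unchanged. Its interpreter substitution is worth a quick illustration: whatever "python"/"pip" spelling a project.yml step uses, the step runs with the same interpreter and pip that spaCy itself was launched with.

    import sys

    command = ["python3", "scripts/train.py"]  # as written in a project.yml step
    if len(command) and command[0] in ("python", "python3"):
        command[0] = sys.executable
    elif len(command) and command[0] in ("pip", "pip3"):
        command = [sys.executable, "-m", "pip", *command[1:]]
    print(command)  # e.g. ['/usr/local/bin/python3.10', 'scripts/train.py']
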


def validate_subcommand(
    commands: Sequence[str], workflows: Sequence[str], subcommand: str
) -> None:
@@ -307,110 +255,3 @@ def validate_subcommand(
        ". ".join(help_msg),
        exits=1,
    )
-
-
-def check_rerun(
-    project_dir: Path,
-    command: Dict[str, Any],
-    *,
-    check_spacy_version: bool = True,
-    check_spacy_commit: bool = False,
-) -> bool:
-    """Check if a command should be rerun because its settings or inputs/outputs
-    changed.
-
-    project_dir (Path): The current project directory.
-    command (Dict[str, Any]): The command, as defined in the project.yml.
-    strict_version (bool):
-    RETURNS (bool): Whether to re-run the command.
-    """
-    # Always rerun if no-skip is set
-    if command.get("no_skip", False):
-        return True
-    lock_path = project_dir / PROJECT_LOCK
-    if not lock_path.exists():  # We don't have a lockfile, run command
-        return True
-    data = srsly.read_yaml(lock_path)
-    if command["name"] not in data:  # We don't have info about this command
-        return True
-    entry = data[command["name"]]
-    # Always run commands with no outputs (otherwise they'd always be skipped)
-    if not entry.get("outs", []):
-        return True
-    # Always rerun if spaCy version or commit hash changed
-    spacy_v = entry.get("spacy_version")
-    commit = entry.get("spacy_git_version")
-    if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
-        info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
-        msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
-        return True
-    if check_spacy_commit and commit != GIT_VERSION:
-        info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
-        msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
-        return True
-    # If the entry in the lockfile matches the lockfile entry that would be
-    # generated from the current command, we don't rerun because it means that
-    # all inputs/outputs, hashes and scripts are the same and nothing changed
-    lock_entry = get_lock_entry(project_dir, command)
-    exclude = ["spacy_version", "spacy_git_version"]
-    return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
-
-
-def update_lockfile(
-    project_dir: Path,
-    command: Dict[str, Any],
-) -> None:
-    """Update the lockfile after running a command. Will create a lockfile if
-    it doesn't yet exist and will add an entry for the current command, its
-    script and dependencies/outputs.
-
-    project_dir (Path): The current project directory.
-    command (Dict[str, Any]): The command, as defined in the project.yml.
-    """
-    lock_path = project_dir / PROJECT_LOCK
-    if not lock_path.exists():
-        srsly.write_yaml(lock_path, {})
-        data = {}
-    else:
-        data = srsly.read_yaml(lock_path)
-    data[command["name"]] = get_lock_entry(project_dir, command)
-    srsly.write_yaml(lock_path, data)
-
-
-def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
-    """Get a lockfile entry for a given command. An entry includes the command,
-    the script (command steps) and a list of dependencies and outputs with
-    their paths and file hashes, if available. The format is based on the
-    dvc.lock files, to keep things consistent.
-
-    project_dir (Path): The current project directory.
-    command (Dict[str, Any]): The command, as defined in the project.yml.
-    RETURNS (Dict[str, Any]): The lockfile entry.
-    """
-    deps = get_fileinfo(project_dir, command.get("deps", []))
-    outs = get_fileinfo(project_dir, command.get("outputs", []))
-    outs_nc = get_fileinfo(project_dir, command.get("outputs_no_cache", []))
-    return {
-        "cmd": f"{COMMAND} run {command['name']}",
-        "script": command["script"],
-        "deps": deps,
-        "outs": [*outs, *outs_nc],
-        "spacy_version": about.__version__,
-        "spacy_git_version": GIT_VERSION,
-    }
-
-
-def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, Optional[str]]]:
-    """Generate the file information for a list of paths (dependencies, outputs).
-    Includes the file path and the file's checksum.
-
-    project_dir (Path): The current project directory.
-    paths (List[str]): The file paths.
-    RETURNS (List[Dict[str, str]]): The lockfile entry for a file.
-    """
-    data = []
-    for path in paths:
-        file_path = project_dir / path
-        md5 = get_checksum(file_path) if file_path.exists() else None
-        data.append({"path": path, "md5": md5})
-    return data
@@ -19,7 +19,6 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
 from spacy.cli.package import get_third_party_dependencies
 from spacy.cli.package import _is_permitted_package_name
 from spacy.cli.validate import get_model_pkgs
-from spacy.cli.project.run import project_run
 from spacy.lang.en import English
 from spacy.lang.nl import Dutch
 from spacy.language import Language
@@ -422,170 +421,6 @@ def test_project_config_interpolation(int_value):
        substitute_project_variables(project)


-def test_project_config_multiprocessing_good_case():
-    project = {
-        "commands": [
-            {"name": "command1", "script": ["echo", "command1"]},
-            {"name": "command2", "script": ["echo", "command2"]},
-            {"name": "command3", "script": ["echo", "command3"]},
-        ],
-        "workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]},
-    }
-    with make_tempdir() as d:
-        srsly.write_yaml(d / "project.yml", project)
-        load_project_config(d)
-
-
-@pytest.mark.parametrize(
-    "workflows",
-    [
-        {"all": ["command1", {"parallel": ["command2"]}, "command3"]},
-        {"all": ["command1", {"parallel": ["command2", "command4"]}]},
-        {"all": ["command1", {"parallel": ["command2", "command2"]}]},
-        {"all": ["command1", {"serial": ["command2", "command3"]}]},
-    ],
-)
-def test_project_config_multiprocessing_bad_case_workflows(workflows):
-    project = {
-        "commands": [
-            {"name": "command1", "script": ["echo", "command1"]},
-            {"name": "command2", "script": ["echo", "command2"]},
-            {"name": "command3", "script": ["echo", "command3"]},
-        ],
-        "workflows": workflows,
-    }
-    with make_tempdir() as d:
-        srsly.write_yaml(d / "project.yml", project)
-        with pytest.raises(SystemExit):
-            load_project_config(d)
-
-
-@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
-def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
-    with make_tempdir() as d:
-        project = {
-            "max_parallel_processes": max_parallel_processes,
-            "commands": [
-                {
-                    "name": "commandA",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
-                },
-                {
-                    "name": "commandB",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
-                },
-                {
-                    "name": "commandC",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
-                },
-            ],
-            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]},
-        }
-        with make_tempdir() as d:
-            srsly.write_yaml(d / "project.yml", project)
-            with pytest.raises(SystemExit):
-                load_project_config(d)
-
-
-def test_project_run_multiprocessing_good_case():
-    with make_tempdir() as d:
-
-        pscript = """
-import sys, os
-from time import sleep
-
-_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
-with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
-    out_file.write("")
-while True:
-    if os.path.exists(os.sep.join((d_path, in_filename))):
-        break
-    sleep(0.1)
-if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
-    with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
-        first_flag_file.write("")
-else:  # should never happen because of skipping
-    with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
-        second_flag_file.write("")
-"""
-
-        pscript_loc = os.sep.join((str(d), "pscript.py"))
-        with open(pscript_loc, "w") as pscript_file:
-            pscript_file.write(pscript)
-        os.chmod(pscript_loc, 0o777)
-        project = {
-            "commands": [
-                {
-                    "name": "commandA",
-                    "script": [
-                        " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
-                    ],
-                    "outputs": [os.sep.join((str(d), "f"))],
-                },
-                {
-                    "name": "commandB",
-                    "script": [
-                        " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
-                    ],
-                    "outputs": [os.sep.join((str(d), "e"))],
-                },
-            ],
-            "workflows": {
-                "all": [
-                    {"parallel": ["commandA", "commandB"]},
-                    {"parallel": ["commandB", "commandA"]},
-                ]
-            },
-        }
-        srsly.write_yaml(d / "project.yml", project)
-        load_project_config(d)
-        project_run(d, "all")
-        assert os.path.exists(os.sep.join((str(d), "c")))
-        assert os.path.exists(os.sep.join((str(d), "e")))
-        assert not os.path.exists(os.sep.join((str(d), "d")))
-        assert not os.path.exists(os.sep.join((str(d), "f")))
-
-
-@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6])
-def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
-    with make_tempdir() as d:
-
-        project = {
-            "max_parallel_processes": max_parallel_processes,
-            "commands": [
-                {
-                    "name": "commandA",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
-                },
-                {
-                    "name": "commandB",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
-                },
-                {
-                    "name": "commandC",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
-                },
-                {
-                    "name": "commandD",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "D"))))],
-                },
-                {
-                    "name": "commandE",
-                    "script": [" ".join(("touch", os.sep.join((str(d), "E"))))],
-                },
-            ],
-            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC", "commandD", "commandE"]}]},
-        }
-        srsly.write_yaml(d / "project.yml", project)
-        load_project_config(d)
-        project_run(d, "all")
-        assert os.path.exists(os.sep.join((str(d), "A")))
-        assert os.path.exists(os.sep.join((str(d), "B")))
-        assert os.path.exists(os.sep.join((str(d), "C")))
-        assert os.path.exists(os.sep.join((str(d), "D")))
-        assert os.path.exists(os.sep.join((str(d), "E")))


 @pytest.mark.parametrize(
    "greeting",
    [342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
spacy/tests/test_parallel.py (new file, 183 lines)
@@ -0,0 +1,183 @@
import os

import pytest
import srsly
from spacy.cli._util import load_project_config
from spacy.cli.project.run import project_run
from .util import make_tempdir


def test_project_config_multiprocessing_good_case():
    project = {
        "commands": [
            {"name": "command1", "script": ["echo", "command1"]},
            {"name": "command2", "script": ["echo", "command2"]},
            {"name": "command3", "script": ["echo", "command3"]},
        ],
        "workflows": {"all": ["command1", {"parallel": ["command2", "command3"]}]},
    }
    with make_tempdir() as d:
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)


@pytest.mark.parametrize(
    "workflows",
    [
        {"all": ["command1", {"parallel": ["command2"]}, "command3"]},
        {"all": ["command1", {"parallel": ["command2", "command4"]}]},
        {"all": ["command1", {"parallel": ["command2", "command2"]}]},
        {"all": ["command1", {"serial": ["command2", "command3"]}]},
    ],
)
def test_project_config_multiprocessing_bad_case_workflows(workflows):
    project = {
        "commands": [
            {"name": "command1", "script": ["echo", "command1"]},
            {"name": "command2", "script": ["echo", "command2"]},
            {"name": "command3", "script": ["echo", "command3"]},
        ],
        "workflows": workflows,
    }
    with make_tempdir() as d:
        srsly.write_yaml(d / "project.yml", project)
        with pytest.raises(SystemExit):
            load_project_config(d)


@pytest.mark.parametrize("max_parallel_processes", [-1, 0, 1])
def test_project_config_multiprocessing_max_processes_bad_case(max_parallel_processes):
    with make_tempdir() as d:
        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
            ],
            "workflows": {"all": [{"parallel": ["commandA", "commandB", "commandC"]}]},
        }
        with make_tempdir() as d:
            srsly.write_yaml(d / "project.yml", project)
            with pytest.raises(SystemExit):
                load_project_config(d)


def test_project_run_multiprocessing_good_case():
    with make_tempdir() as d:

        pscript = """
import sys, os
from time import sleep

_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
    out_file.write("")
while True:
    if os.path.exists(os.sep.join((d_path, in_filename))):
        break
    sleep(0.1)
if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
    with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
        first_flag_file.write("")
else:  # should never happen because of skipping
    with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
        second_flag_file.write("")
"""

        pscript_loc = os.sep.join((str(d), "pscript.py"))
        with open(pscript_loc, "w") as pscript_file:
            pscript_file.write(pscript)
        os.chmod(pscript_loc, 0o777)
        project = {
            "commands": [
                {
                    "name": "commandA",
                    "script": [
                        " ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
                    ],
                    "outputs": [os.sep.join((str(d), "f"))],
                },
                {
                    "name": "commandB",
                    "script": [
                        " ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
                    ],
                    "outputs": [os.sep.join((str(d), "e"))],
                },
            ],
            "workflows": {
                "all": [
                    {"parallel": ["commandA", "commandB"]},
                    {"parallel": ["commandB", "commandA"]},
                ]
            },
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        assert os.path.exists(os.sep.join((str(d), "c")))
        assert os.path.exists(os.sep.join((str(d), "e")))
        assert not os.path.exists(os.sep.join((str(d), "d")))
        assert not os.path.exists(os.sep.join((str(d), "f")))
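
Note: this test is a two-way file handshake: commandA writes "b" and then polls for "a", while commandB writes "a" and polls for "b", so neither can finish unless both really run at the same time. The flag files "c"/"e" prove the first parallel group ran both commands, and the absent "d"/"f" show the re-ordered second group was skipped via the lockfile. The handshake distilled to a standalone script:

    import multiprocessing
    import os
    import tempfile
    import time

    def worker(d: str, write_name: str, wait_name: str) -> None:
        # Write our own marker, then poll for the peer's marker; this only
        # terminates if both workers are running concurrently.
        open(os.path.join(d, write_name), "w").close()
        while not os.path.exists(os.path.join(d, wait_name)):
            time.sleep(0.1)

    if __name__ == "__main__":
        d = tempfile.mkdtemp()
        a = multiprocessing.Process(target=worker, args=(d, "b", "a"))
        b = multiprocessing.Process(target=worker, args=(d, "a", "b"))
        a.start()
        b.start()
        a.join()
        b.join()
        print("both finished, so the two processes ran concurrently")
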


@pytest.mark.parametrize("max_parallel_processes", [2, 3, 4, 5, 6])
def test_project_run_multiprocessing_max_processes_good_case(max_parallel_processes):
    with make_tempdir() as d:

        project = {
            "max_parallel_processes": max_parallel_processes,
            "commands": [
                {
                    "name": "commandA",
                    "script": [" ".join(("touch", os.sep.join((str(d), "A"))))],
                },
                {
                    "name": "commandB",
                    "script": [" ".join(("touch", os.sep.join((str(d), "B"))))],
                },
                {
                    "name": "commandC",
                    "script": [" ".join(("touch", os.sep.join((str(d), "C"))))],
                },
                {
                    "name": "commandD",
                    "script": [" ".join(("touch", os.sep.join((str(d), "D"))))],
                },
                {
                    "name": "commandE",
                    "script": [" ".join(("touch", os.sep.join((str(d), "E"))))],
                },
            ],
            "workflows": {
                "all": [
                    {
                        "parallel": [
                            "commandA",
                            "commandB",
                            "commandC",
                            "commandD",
                            "commandE",
                        ]
                    }
                ]
            },
        }
        srsly.write_yaml(d / "project.yml", project)
        load_project_config(d)
        project_run(d, "all")
        assert os.path.exists(os.sep.join((str(d), "A")))
        assert os.path.exists(os.sep.join((str(d), "B")))
        assert os.path.exists(os.sep.join((str(d), "C")))
        assert os.path.exists(os.sep.join((str(d), "D")))
        assert os.path.exists(os.sep.join((str(d), "E")))