Mypy corrections

This commit is contained in:
richardpaulhudson 2022-05-09 19:03:57 +02:00
parent 12e86004c8
commit e3b4ee7b15
2 changed files with 15 additions and 6 deletions

View File

@ -1,3 +1,4 @@
import multiprocessing
from pathlib import Path
from wasabi import msg
from .remote_storage import RemoteStorage
@ -37,6 +38,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
# We use a while loop here because we don't know how the commands
# will be ordered. A command might need dependencies from one that's later
# in the list.
mult_group_mutex = multiprocessing.Lock()
while commands:
for i, cmd in enumerate(list(commands)):
logger.debug(f"CMD: {cmd['name']}.")
@ -52,7 +54,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
out_locs = [project_dir / out for out in cmd.get("outputs", [])]
if all(loc.exists() for loc in out_locs):
update_lockfile(project_dir, cmd)
update_lockfile(project_dir, cmd, mult_group_mutex=mult_group_mutex)
# We remove the command from the list here, and break, so that
# we iterate over the loop again.
commands.pop(i)

View File

@ -1,6 +1,7 @@
from typing import Optional, List, Dict, Sequence, Any, Iterable, cast
from pathlib import Path
from multiprocessing import Process, Lock
from multiprocessing.synchronize import Lock as Lock_t
from wasabi import msg
from wasabi.util import locale_escape
import sys
@ -51,7 +52,7 @@ def project_run(
force: bool = False,
dry: bool = False,
capture: bool = False,
mult_group_mutex: Optional[Lock] = None,
mult_group_mutex: Optional[Lock_t] = None,
) -> None:
"""Run a named script defined in the project.yml. If the script is part
of the default pipeline (defined in the "run" section), DVC is used to
@ -89,7 +90,8 @@ def project_run(
mult_group_mutex=mult_group_mutex,
)
else:
processes = [Process(
processes = [
Process(
target=project_run,
args=(project_dir, cmd),
kwargs={
@ -98,7 +100,10 @@ def project_run(
"dry": dry,
"capture": capture,
"mult_group_mutex": mult_group_mutex,
}) for cmd in cmdOrMultiprocessingGroup]
},
)
for cmd in cmdOrMultiprocessingGroup
]
for process in processes:
process.start()
for process in processes:
@ -238,7 +243,7 @@ def check_rerun(
*,
check_spacy_version: bool = True,
check_spacy_commit: bool = False,
mult_group_mutex: Lock,
mult_group_mutex: Lock_t,
) -> bool:
"""Check if a command should be rerun because its settings or inputs/outputs
changed.
@ -286,7 +291,9 @@ def check_rerun(
def update_lockfile(
project_dir: Path, command: Dict[str, Any], mult_group_mutex: Lock
project_dir: Path,
command: Dict[str, Any],
mult_group_mutex: Lock_t,
) -> None:
"""Update the lockfile after running a command. Will create a lockfile if
it doesn't yet exist and will add an entry for the current command, its