From e3b4ee7b151deba540929be1fe783f1ed95752d0 Mon Sep 17 00:00:00 2001 From: richardpaulhudson Date: Mon, 9 May 2022 19:03:57 +0200 Subject: [PATCH] Mypy corrections --- spacy/cli/project/pull.py | 4 +++- spacy/cli/project/run.py | 17 ++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/spacy/cli/project/pull.py b/spacy/cli/project/pull.py index 6e3cde88c..9496cdb1a 100644 --- a/spacy/cli/project/pull.py +++ b/spacy/cli/project/pull.py @@ -1,3 +1,4 @@ +import multiprocessing from pathlib import Path from wasabi import msg from .remote_storage import RemoteStorage @@ -37,6 +38,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False): # We use a while loop here because we don't know how the commands # will be ordered. A command might need dependencies from one that's later # in the list. + mult_group_mutex = multiprocessing.Lock() while commands: for i, cmd in enumerate(list(commands)): logger.debug(f"CMD: {cmd['name']}.") @@ -52,7 +54,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False): out_locs = [project_dir / out for out in cmd.get("outputs", [])] if all(loc.exists() for loc in out_locs): - update_lockfile(project_dir, cmd) + update_lockfile(project_dir, cmd, mult_group_mutex=mult_group_mutex) # We remove the command from the list here, and break, so that # we iterate over the loop again. commands.pop(i) diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py index 24b827083..da1bdef90 100644 --- a/spacy/cli/project/run.py +++ b/spacy/cli/project/run.py @@ -1,6 +1,7 @@ from typing import Optional, List, Dict, Sequence, Any, Iterable, cast from pathlib import Path from multiprocessing import Process, Lock +from multiprocessing.synchronize import Lock as Lock_t from wasabi import msg from wasabi.util import locale_escape import sys @@ -51,7 +52,7 @@ def project_run( force: bool = False, dry: bool = False, capture: bool = False, - mult_group_mutex: Optional[Lock] = None, + mult_group_mutex: Optional[Lock_t] = None, ) -> None: """Run a named script defined in the project.yml. If the script is part of the default pipeline (defined in the "run" section), DVC is used to @@ -89,7 +90,8 @@ def project_run( mult_group_mutex=mult_group_mutex, ) else: - processes = [Process( + processes = [ + Process( target=project_run, args=(project_dir, cmd), kwargs={ @@ -98,7 +100,10 @@ def project_run( "dry": dry, "capture": capture, "mult_group_mutex": mult_group_mutex, - }) for cmd in cmdOrMultiprocessingGroup] + }, + ) + for cmd in cmdOrMultiprocessingGroup + ] for process in processes: process.start() for process in processes: @@ -238,7 +243,7 @@ def check_rerun( *, check_spacy_version: bool = True, check_spacy_commit: bool = False, - mult_group_mutex: Lock, + mult_group_mutex: Lock_t, ) -> bool: """Check if a command should be rerun because its settings or inputs/outputs changed. @@ -286,7 +291,9 @@ def check_rerun( def update_lockfile( - project_dir: Path, command: Dict[str, Any], mult_group_mutex: Lock + project_dir: Path, + command: Dict[str, Any], + mult_group_mutex: Lock_t, ) -> None: """Update the lockfile after running a command. Will create a lockfile if it doesn't yet exist and will add an entry for the current command, its