mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Basic multiprocessing functionality
This commit is contained in:
parent
8d08a68174
commit
12e86004c8
|
@ -249,10 +249,16 @@ def validate_project_commands(config: Dict[str, Any]) -> None:
|
|||
workflow_list = cast(List[str], step_or_list)
|
||||
if len(workflow_list) < 2:
|
||||
msg.fail(
|
||||
f"Invalid multiprocessing group within '{workflow_name}'",
|
||||
f"Invalid multiprocessing group within '{workflow_name}'.",
|
||||
f"A multiprocessing group must reference at least two commands.",
|
||||
exits=1,
|
||||
)
|
||||
if len(workflow_list) != len(set(workflow_list)):
|
||||
msg.fail(
|
||||
f"A multiprocessing group within '{workflow_name}' contains a command more than once.",
|
||||
f"This is not permitted because it is then not possible to determine when to rerun.",
|
||||
exits=1,
|
||||
)
|
||||
for step in workflow_list:
|
||||
verify_workflow_step(step)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional, List, Dict, Sequence, Any, Iterable
|
||||
from typing import Optional, List, Dict, Sequence, Any, Iterable, cast
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process, Lock
|
||||
from wasabi import msg
|
||||
from wasabi.util import locale_escape
|
||||
import sys
|
||||
|
@ -50,6 +51,7 @@ def project_run(
|
|||
force: bool = False,
|
||||
dry: bool = False,
|
||||
capture: bool = False,
|
||||
mult_group_mutex: Optional[Lock] = None,
|
||||
) -> None:
|
||||
"""Run a named script defined in the project.yml. If the script is part
|
||||
of the default pipeline (defined in the "run" section), DVC is used to
|
||||
|
@ -67,21 +69,40 @@ def project_run(
|
|||
when you want to turn over execution to the command, and capture=True
|
||||
when you want to run the command more like a function.
|
||||
"""
|
||||
if mult_group_mutex is None:
|
||||
mult_group_mutex = Lock()
|
||||
config = load_project_config(project_dir, overrides=overrides)
|
||||
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
|
||||
workflows = config.get("workflows", {})
|
||||
validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
|
||||
if subcommand in workflows:
|
||||
msg.info(f"Running workflow '{subcommand}'")
|
||||
for cmd in workflows[subcommand]:
|
||||
project_run(
|
||||
project_dir,
|
||||
cmd,
|
||||
overrides=overrides,
|
||||
force=force,
|
||||
dry=dry,
|
||||
capture=capture,
|
||||
)
|
||||
for cmdOrMultiprocessingGroup in workflows[subcommand]:
|
||||
if isinstance(cmdOrMultiprocessingGroup, str):
|
||||
project_run(
|
||||
project_dir,
|
||||
cmdOrMultiprocessingGroup,
|
||||
overrides=overrides,
|
||||
force=force,
|
||||
dry=dry,
|
||||
capture=capture,
|
||||
mult_group_mutex=mult_group_mutex,
|
||||
)
|
||||
else:
|
||||
processes = [Process(
|
||||
target=project_run,
|
||||
args=(project_dir, cmd),
|
||||
kwargs={
|
||||
"overrides": overrides,
|
||||
"force": force,
|
||||
"dry": dry,
|
||||
"capture": capture,
|
||||
"mult_group_mutex": mult_group_mutex,
|
||||
}) for cmd in cmdOrMultiprocessingGroup]
|
||||
for process in processes:
|
||||
process.start()
|
||||
for process in processes:
|
||||
process.join()
|
||||
else:
|
||||
cmd = commands[subcommand]
|
||||
for dep in cmd.get("deps", []):
|
||||
|
@ -93,13 +114,18 @@ def project_run(
|
|||
check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
|
||||
with working_dir(project_dir) as current_dir:
|
||||
msg.divider(subcommand)
|
||||
rerun = check_rerun(current_dir, cmd, check_spacy_commit=check_spacy_commit)
|
||||
rerun = check_rerun(
|
||||
current_dir,
|
||||
cmd,
|
||||
check_spacy_commit=check_spacy_commit,
|
||||
mult_group_mutex=mult_group_mutex,
|
||||
)
|
||||
if not rerun and not force:
|
||||
msg.info(f"Skipping '{cmd['name']}': nothing changed")
|
||||
else:
|
||||
run_commands(cmd["script"], dry=dry, capture=capture)
|
||||
if not dry:
|
||||
update_lockfile(current_dir, cmd)
|
||||
update_lockfile(current_dir, cmd, mult_group_mutex=mult_group_mutex)
|
||||
|
||||
|
||||
def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
|
||||
|
@ -157,7 +183,7 @@ def run_commands(
|
|||
|
||||
commands (List[str]): The string commands.
|
||||
silent (bool): Don't print the commands.
|
||||
dry (bool): Perform a dry run and don't execut anything.
|
||||
dry (bool): Perform a dry run and don't execute anything.
|
||||
capture (bool): Whether to capture the output and errors of individual commands.
|
||||
If False, the stdout and stderr will not be redirected, and if there's an error,
|
||||
sys.exit will be called with the return code. You should use capture=False
|
||||
|
@ -212,6 +238,7 @@ def check_rerun(
|
|||
*,
|
||||
check_spacy_version: bool = True,
|
||||
check_spacy_commit: bool = False,
|
||||
mult_group_mutex: Lock,
|
||||
) -> bool:
|
||||
"""Check if a command should be rerun because its settings or inputs/outputs
|
||||
changed.
|
||||
|
@ -224,51 +251,60 @@ def check_rerun(
|
|||
# Always rerun if no-skip is set
|
||||
if command.get("no_skip", False):
|
||||
return True
|
||||
lock_path = project_dir / PROJECT_LOCK
|
||||
if not lock_path.exists(): # We don't have a lockfile, run command
|
||||
return True
|
||||
data = srsly.read_yaml(lock_path)
|
||||
if command["name"] not in data: # We don't have info about this command
|
||||
return True
|
||||
entry = data[command["name"]]
|
||||
# Always run commands with no outputs (otherwise they'd always be skipped)
|
||||
if not entry.get("outs", []):
|
||||
return True
|
||||
# Always rerun if spaCy version or commit hash changed
|
||||
spacy_v = entry.get("spacy_version")
|
||||
commit = entry.get("spacy_git_version")
|
||||
if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
|
||||
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
|
||||
msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
|
||||
return True
|
||||
if check_spacy_commit and commit != GIT_VERSION:
|
||||
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
|
||||
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
|
||||
return True
|
||||
# If the entry in the lockfile matches the lockfile entry that would be
|
||||
# generated from the current command, we don't rerun because it means that
|
||||
# all inputs/outputs, hashes and scripts are the same and nothing changed
|
||||
lock_entry = get_lock_entry(project_dir, command)
|
||||
exclude = ["spacy_version", "spacy_git_version"]
|
||||
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
|
||||
with mult_group_mutex:
|
||||
lock_path = project_dir / PROJECT_LOCK
|
||||
if not lock_path.exists(): # We don't have a lockfile, run command
|
||||
return True
|
||||
data = srsly.read_yaml(lock_path)
|
||||
if command["name"] not in data: # We don't have info about this command
|
||||
return True
|
||||
entry = data[command["name"]]
|
||||
# Always run commands with no outputs (otherwise they'd always be skipped)
|
||||
if not entry.get("outs", []):
|
||||
return True
|
||||
# Always rerun if spaCy version or commit hash changed
|
||||
spacy_v = entry.get("spacy_version")
|
||||
commit = entry.get("spacy_git_version")
|
||||
if check_spacy_version and not is_minor_version_match(
|
||||
spacy_v, about.__version__
|
||||
):
|
||||
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
|
||||
msg.info(
|
||||
f"Re-running '{command['name']}': spaCy minor version changed {info}"
|
||||
)
|
||||
return True
|
||||
if check_spacy_commit and commit != GIT_VERSION:
|
||||
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
|
||||
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
|
||||
return True
|
||||
# If the entry in the lockfile matches the lockfile entry that would be
|
||||
# generated from the current command, we don't rerun because it means that
|
||||
# all inputs/outputs, hashes and scripts are the same and nothing changed
|
||||
lock_entry = get_lock_entry(project_dir, command)
|
||||
exclude = ["spacy_version", "spacy_git_version"]
|
||||
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
|
||||
|
||||
|
||||
def update_lockfile(project_dir: Path, command: Dict[str, Any]) -> None:
|
||||
def update_lockfile(
|
||||
project_dir: Path, command: Dict[str, Any], mult_group_mutex: Lock
|
||||
) -> None:
|
||||
"""Update the lockfile after running a command. Will create a lockfile if
|
||||
it doesn't yet exist and will add an entry for the current command, its
|
||||
script and dependencies/outputs.
|
||||
|
||||
project_dir (Path): The current project directory.
|
||||
command (Dict[str, Any]): The command, as defined in the project.yml.
|
||||
mult_group_mutex: the mutex preventing concurrent writes
|
||||
"""
|
||||
lock_path = project_dir / PROJECT_LOCK
|
||||
if not lock_path.exists():
|
||||
srsly.write_yaml(lock_path, {})
|
||||
data = {}
|
||||
else:
|
||||
data = srsly.read_yaml(lock_path)
|
||||
data[command["name"]] = get_lock_entry(project_dir, command)
|
||||
srsly.write_yaml(lock_path, data)
|
||||
with mult_group_mutex:
|
||||
lock_path = project_dir / PROJECT_LOCK
|
||||
if not lock_path.exists():
|
||||
srsly.write_yaml(lock_path, {})
|
||||
data = {}
|
||||
else:
|
||||
data = srsly.read_yaml(lock_path)
|
||||
data[command["name"]] = get_lock_entry(project_dir, command)
|
||||
srsly.write_yaml(lock_path, data)
|
||||
|
||||
|
||||
def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
|
|
@ -19,6 +19,7 @@ from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
|||
from spacy.cli.package import get_third_party_dependencies
|
||||
from spacy.cli.package import _is_permitted_package_name
|
||||
from spacy.cli.validate import get_model_pkgs
|
||||
from spacy.cli.project.run import project_run
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.nl import Dutch
|
||||
from spacy.language import Language
|
||||
|
@ -440,6 +441,7 @@ def test_project_config_multiprocessing_good_case():
|
|||
[
|
||||
{"all": ["command1", ["command2"], "command3"]},
|
||||
{"all": ["command1", ["command2", "command4"]]},
|
||||
{"all": ["command1", ["command2", "command2"]]},
|
||||
],
|
||||
)
|
||||
def test_project_config_multiprocessing_bad_case(workflows):
|
||||
|
@ -457,6 +459,60 @@ def test_project_config_multiprocessing_bad_case(workflows):
|
|||
load_project_config(d)
|
||||
|
||||
|
||||
def test_project_run_multiprocessing_good_case():
|
||||
with make_tempdir() as d:
|
||||
|
||||
pscript = """
|
||||
import sys, os
|
||||
from time import sleep
|
||||
|
||||
_, d_path, in_filename, out_filename, first_flag_filename, second_flag_filename = sys.argv
|
||||
with open(os.sep.join((d_path, out_filename)), 'w') as out_file:
|
||||
out_file.write("")
|
||||
while True:
|
||||
if os.path.exists(os.sep.join((d_path, in_filename))):
|
||||
break
|
||||
sleep(0.1)
|
||||
if not os.path.exists(os.sep.join((d_path, first_flag_filename))):
|
||||
with open(os.sep.join((d_path, first_flag_filename)), 'w') as first_flag_file:
|
||||
first_flag_file.write("")
|
||||
else: # should never happen because of skipping
|
||||
with open(os.sep.join((d_path, second_flag_filename)), 'w') as second_flag_file:
|
||||
second_flag_file.write("")
|
||||
"""
|
||||
|
||||
pscript_loc = os.sep.join((str(d), "pscript.py"))
|
||||
with open(pscript_loc, "w") as pscript_file:
|
||||
pscript_file.write(pscript)
|
||||
os.chmod(pscript_loc, 0o777)
|
||||
project = {
|
||||
"commands": [
|
||||
{
|
||||
"name": "commandA",
|
||||
"script": [
|
||||
" ".join(("python", pscript_loc, str(d), "a", "b", "c", "d"))
|
||||
],
|
||||
"outputs": [" ".join((str(d), "c"))],
|
||||
},
|
||||
{
|
||||
"name": "commandB",
|
||||
"script": [
|
||||
" ".join(("python", pscript_loc, str(d), "b", "a", "e", "f"))
|
||||
],
|
||||
"outputs": [" ".join((str(d), "e"))],
|
||||
},
|
||||
],
|
||||
"workflows": {"all": [["commandA", "commandB"], ["commandA", "commandB"]]},
|
||||
}
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
load_project_config(d)
|
||||
project_run(d, "all")
|
||||
assert os.path.exists(os.sep.join((str(d), "c")))
|
||||
assert os.path.exists(os.sep.join((str(d), "e")))
|
||||
assert not os.path.exists(os.sep.join((str(d), "d")))
|
||||
assert not os.path.exists(os.sep.join((str(d), "f")))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"greeting",
|
||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
||||
|
|
Loading…
Reference in New Issue
Block a user