Commit 6ae2ea3

first try at supporting multi-node blocks

1 parent: cfd48d3
2 files changed: +73 -79 lines

psiflow/execution.py (+38 -73)
@@ -19,12 +19,12 @@
 from parsl.config import Config
 from parsl.data_provider.files import File
 from parsl.executors import HighThroughputExecutor, ThreadPoolExecutor
-from parsl.launchers.launchers import SimpleLauncher
+from parsl.launchers.launchers import SimpleLauncher, SrunLauncher
 from parsl.providers import *  # noqa: F403
 from parsl.providers.base import ExecutionProvider
 
 from psiflow.models import BaseModel
-from psiflow.parsl_utils import ContainerizedLauncher, MyWorkQueueExecutor
+from psiflow.parsl_utils import ContainerizedLauncher, ContainerizedSrunLauncher
 from psiflow.reference import BaseReference
 from psiflow.utils import resolve_and_check, set_logger
 
@@ -197,6 +197,7 @@ class ExecutionContextLoader:
     def parse_config(yaml_dict: dict):
         definitions = []
 
+        container_dict = yaml_dict.pop("container", None)
         for name in ["ModelEvaluation", "ModelTraining", "ReferenceEvaluation"]:
             if name in yaml_dict:
                 _dict = yaml_dict.pop(name)
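For context, the new top-level pop lets a single container section at the root of the config act as a shared default for every executor section. A minimal sketch; the field names inside the container dict ("engine", "uri") are illustrative placeholders, not a documented schema:

# Sketch of the new top-level container handling; inner field names are
# hypothetical placeholders.
yaml_dict = {
    "container": {"engine": "apptainer", "uri": "oras://ghcr.io/example/image:tag"},
    "ModelEvaluation": {},
    "ModelTraining": {},
}
container_dict = yaml_dict.pop("container", None)  # shared default; None if absent
assert container_dict is not None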
@@ -223,25 +224,45 @@ def parse_config(yaml_dict: dict):
                 s = _dict["mpi_command"]
                 _dict["mpi_command"] = lambda x, s=s: s.format(x)
 
-            if "container" in yaml_dict:
-                assert not _dict["use_threadpool"]  # not possible with container
+            # set up containerized launcher if necessary
+            if ("container" not in _dict and container_dict is None) or _dict[
+                "use_threadpool"
+            ]:
+                launcher = SimpleLauncher()
+                _container_dict = None
+            else:
+                _container_dict = yaml_dict.pop("container", container_dict)
+                assert _container_dict is not None
                 launcher = ContainerizedLauncher(
-                    **yaml_dict["container"], enable_gpu=_dict["gpu"]
+                    **_container_dict,
+                    enable_gpu=_dict["gpu"],
                 )
-            else:
-                launcher = SimpleLauncher()
 
             # initialize provider
-            provider_dict = None
-            for key in _dict:
-                if "Provider" in key:
-                    assert provider_dict is None
-                    provider_dict = _dict[key]
-            if provider_dict is not None:
-                provider_cls = getattr(sys.modules[__name__], key)
-                provider = provider_cls(launcher=launcher, **_dict.pop(key))
-            else:
+            provider_keys = list(filter(lambda k: "Provider" in k, _dict.keys()))
+            if len(provider_keys) == 0:
                 provider = LocalProvider(launcher=launcher)  # noqa: F405
+            elif len(provider_keys) == 1:
+                provider_dict = _dict[provider_keys[0]]
+
+                # if provider requests multiple nodes, switch to (containerized) SrunLauncher
+                if (
+                    provider_dict.pop("nodes_per_block", 1) > 1
+                    and "container" in yaml_dict
+                ):
+                    assert (
+                        provider_keys[0] == "SlurmProvider"
+                    ), "multi-node blocks only supported for SLURM"
+                    if _container_dict is not None:
+                        launcher = ContainerizedSrunLauncher(
+                            **_container_dict, enable_gpu=_dict["gpu"]
+                        )
+                    else:
+                        launcher = SrunLauncher()
+                provider_cls = getattr(sys.modules[__name__], provider_keys[0])
+                provider = provider_cls(launcher=launcher, **provider_dict)
+            else:
+                raise ValueError("Can only have one provider per executor")
 
             # initialize definition
             definition_cls = getattr(sys.modules[__name__], name)
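Pulled out of the diff, the new provider-selection rule is easier to follow. A minimal runnable sketch with illustrative SLURM options; note that, as in the diff, pop removes nodes_per_block before the remaining kwargs reach the provider constructor:

# Sketch of the provider selection introduced above: exactly one "*Provider"
# key is allowed per executor section, and nodes_per_block > 1 flips the
# launcher to SrunLauncher (or ContainerizedSrunLauncher with a container).
_dict = {
    "gpu": False,
    "use_threadpool": False,
    "SlurmProvider": {"partition": "compute", "nodes_per_block": 4},  # illustrative
}
provider_keys = list(filter(lambda k: "Provider" in k, _dict.keys()))
assert len(provider_keys) == 1, "Can only have one provider per executor"
provider_dict = _dict[provider_keys[0]]
if provider_dict.pop("nodes_per_block", 1) > 1:
    # multi-node block: only SlurmProvider is accepted here
    assert provider_keys[0] == "SlurmProvider"
    print("would switch to (Containerized)SrunLauncher")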
@@ -259,7 +280,6 @@ def parse_config(yaml_dict: dict):
             "default_threads": 1,
             "mode": "htex",
             "htex_address": address_by_hostname(),
-            "workqueue_use_coprocess": False,  # CP2K doesn't like this
         }
         forced = {
             "initialize_logging": False,  # manual; to move parsl.log one level up
@@ -319,18 +339,16 @@ def load(
             path.iterdir()
         ), "internal directory {} should be empty".format(path)
         path.mkdir(parents=True, exist_ok=True)
-        set_logger(psiflow_config.pop("psiflow_log_level"))
         parsl.set_file_logger(
             str(path / "parsl.log"),
             "parsl",
             getattr(logging, psiflow_config.pop("parsl_log_level")),
-            # format_string="%(levelname)s - %(name)s - %(message)s",
         )
+        set_logger(psiflow_config.pop("psiflow_log_level"))
 
         # create main parsl executors
         executors = []
         mode = psiflow_config.pop("mode")
-        use_coprocess = psiflow_config.pop("workqueue_use_coprocess")
         htex_address = psiflow_config.pop("htex_address")
         for definition in definitions:
             if definition.use_threadpool:
@@ -362,61 +380,8 @@ def load(
                 provider=definition.parsl_provider,
                 cpu_affinity=definition.cpu_affinity,
             )
-        elif mode == "workqueue":
-            worker_options = []
-            if hasattr(definition.parsl_provider, "cores_per_node"):
-                worker_options.append(
-                    "--cores={}".format(definition.parsl_provider.cores_per_node),
-                )
-            else:
-                worker_options.append(
-                    "--cores={}".format(psutil.cpu_count(logical=False)),
-                )
-            if hasattr(definition.parsl_provider, "walltime"):
-                walltime_hhmmss = definition.parsl_provider.walltime.split(":")
-                assert len(walltime_hhmmss) == 3
-                walltime = 0
-                walltime += 60 * float(walltime_hhmmss[0])
-                walltime += float(walltime_hhmmss[1])
-                walltime += 1  # whatever seconds are present
-                walltime -= (
-                    5  # add 5 minutes of slack, e.g. for container downloading
-                )
-                worker_options.append("--wall-time={}".format(walltime * 60))
-            worker_options.append("--parent-death")
-            worker_options.append(
-                "--timeout={}".format(psiflow_config["max_idletime"])
-            )
-            # manager_config = TaskVineManagerConfig(
-            #     shared_fs=True,
-            #     max_retries=1,
-            #     autocategory=False,
-            #     enable_peer_transfers=False,
-            #     port=0,
-            # )
-            # factory_config = TaskVineFactoryConfig(
-            #     factory_timeout=20,
-            #     worker_options=' '.join(worker_options),
-            # )
-            executor = MyWorkQueueExecutor(
-                label=definition.name(),
-                working_dir=str(path / definition.name()),
-                provider=definition.parsl_provider,
-                shared_fs=True,
-                autocategory=False,
-                port=0,
-                max_retries=0,
-                coprocess=use_coprocess,
-                worker_options=" ".join(worker_options),
-            )
         else:
             raise ValueError("Unknown mode {}".format(mode))
-            # executor = TaskVineExecutor(
-            #     label=definition.name(),
-            #     provider=definition.parsl_provider,
-            #     manager_config=manager_config,
-            #     factory_config=factory_config,
-            # )
         executors.append(executor)
 
         # create default executors

psiflow/parsl_utils.py (+35 -6)
@@ -4,17 +4,11 @@
 from typing import Optional
 
 import typeguard
-from parsl.executors import WorkQueueExecutor
 from parsl.launchers.launchers import Launcher
 
 logger = logging.getLogger(__name__)
 
 
-class MyWorkQueueExecutor(WorkQueueExecutor):
-    def _get_launch_command(self, block_id):
-        return self.worker_command
-
-
 ADDOPTS = " --no-eval -e --no-mount home -W /tmp --writable-tmpfs"
 ENTRYPOINT = "/usr/local/bin/entry.sh"
 
@@ -66,3 +60,38 @@ def __init__(
 
     def __call__(self, command: str, tasks_per_node: int, nodes_per_block: int) -> str:
         return self.launch_command + "{}".format(command)
+
+
+@typeguard.typechecked
+class ContainerizedSrunLauncher(ContainerizedLauncher):
+    def __init__(self, overrides: str = "", **kwargs):
+        self.overrides = overrides
+        super().__init__(**kwargs)
+
+    def __call__(self, command: str, tasks_per_node: int, nodes_per_block: int) -> str:
+        task_blocks = tasks_per_node * nodes_per_block
+        debug_num = int(self.debug)
+
+        x = """set -e
+export CORES=$SLURM_CPUS_ON_NODE
+export NODES=$SLURM_JOB_NUM_NODES
+
+[[ "{debug}" == "1" ]] && echo "Found cores : $CORES"
+[[ "{debug}" == "1" ]] && echo "Found nodes : $NODES"
+WORKERCOUNT={task_blocks}
+
+cat << SLURM_EOF > cmd_$SLURM_JOB_NAME.sh
+{command}
+SLURM_EOF
+chmod a+x cmd_$SLURM_JOB_NAME.sh
+
+srun --ntasks {task_blocks} -l {overrides} bash cmd_$SLURM_JOB_NAME.sh
+
+[[ "{debug}" == "1" ]] && echo "Done"
+""".format(
+            command=self.launch_command + "{}".format(command),
+            task_blocks=task_blocks,
+            overrides=self.overrides,
+            debug=debug_num,
+        )
+        return x
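A hedged usage sketch of the new launcher. Only overrides and enable_gpu appear in this diff; the uri keyword passed through to ContainerizedLauncher is an assumption, and the worker command is a placeholder:

# Hypothetical usage; "uri" is an assumed ContainerizedLauncher argument, and
# the command string stands in for whatever Parsl hands the launcher.
from psiflow.parsl_utils import ContainerizedSrunLauncher

launcher = ContainerizedSrunLauncher(
    uri="oras://ghcr.io/example/image:tag",  # assumed keyword
    enable_gpu=False,
    overrides="--cpus-per-task=8",  # extra srun flags, inserted after -l
)
script = launcher("process_worker_pool.py", tasks_per_node=1, nodes_per_block=4)
print(script)

The heredoc indirection (writing the containerized command into cmd_$SLURM_JOB_NAME.sh before calling srun) means every node runs the identical wrapper script; in this example srun --ntasks 4 -l starts one copy per allocated node.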
