Skip to content

Commit 6e66163

Browse files
committed
Random tag for the runtime-created ConfigMap - Remove the runtime-created ConfigMap when the pod is removed - First basic tests for Dask
1 parent a27849c commit 6e66163

File tree

4 files changed

+143
-18
lines changed

4 files changed

+143
-18
lines changed

Diff for: calrissian/dask.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
from calrissian.retry import retry_exponential_if_exception_type
1717
from calrissian.job import (
1818
CalrissianCommandLineJob,
19-
KubernetesPodBuilder,
19+
KubernetesPodBuilder
2020
)
2121
from calrissian.job import (
2222
quoted_arg_list,
23-
read_yaml
23+
read_yaml,
24+
random_tag,
25+
k8s_safe_name,
2426
)
2527
from calrissian.job import (
2628
DEFAULT_INIT_IMAGE,
@@ -246,8 +248,6 @@ class CalrissianCommandLineDaskJob(CalrissianCommandLineJob):
246248
daskGateway_controller_dir = '/controller'
247249

248250
daskGateway_config_dir = '/etc/dask'
249-
daskGateway_cm_name = 'dask-gateway-cm'
250-
daskGateway_cm = 'dask-gateway-cm'
251251

252252
daskGateway_controller_cm_name = 'dask-cluster-controller-cm'
253253

@@ -256,8 +256,14 @@ def __init__(self, *args, **kwargs):
256256
super(CalrissianCommandLineDaskJob, self).__init__(*args, **kwargs)
257257
self.client = KubernetesDaskClient()
258258

259-
def wait_for_kubernetes_pod(self):
260-
return self.client.wait_for_completion()
259+
self.dask_cm_name, self.dask_cm_claim_name = self.dask_configmap_name()
260+
261+
def dask_configmap_name(self):
    """Build the pair of names used for the Dask ConfigMap of this job.

    A tag is obtained from ``random_tag()`` and embedded in a string of the
    form ``dask-cm-<tag>``; that string is passed through ``k8s_safe_name``
    once for each member of the returned pair.

    :return: a 2-tuple ``(name, claim_name)``; both members are derived
        from the same tagged string.
    """
    tagged_name = '{}-cm-{}'.format('dask', random_tag())
    configmap_name = k8s_safe_name(tagged_name)
    claim_name = k8s_safe_name(tagged_name)
    return configmap_name, claim_name
264+
265+
def wait_for_kubernetes_pod(self, cm_name: str):
    """Wait for the pod by delegating to the client's ``wait_for_completion``.

    :param cm_name: forwarded unchanged to the client as the ``cm_name``
        keyword argument.
    :return: whatever ``self.client.wait_for_completion`` returns.
    """
    completion_result = self.client.wait_for_completion(cm_name=cm_name)
    return completion_result
261267

262268
def get_dask_gateway_url(self, runtimeContext):
263269
return runtimeContext.dask_gateway_url
@@ -301,7 +307,8 @@ def create_kubernetes_runtime(self, runtimeContext):
301307

302308

303309
self.client.create_dask_gateway_cofig_map(
304-
dask_gateway_url=self.get_dask_gateway_url(runtimeContext))
310+
dask_gateway_url=self.get_dask_gateway_url(runtimeContext),
311+
cm_name=self.dask_cm_name)
305312

306313
# emptyDir volume at /shared for sharing the Dask cluster name between containers
307314
self._add_emptydir_volume_and_binding('shared-data', self.container_shared_dir)
@@ -310,8 +317,8 @@ def create_kubernetes_runtime(self, runtimeContext):
310317
# Need this ConfigMap to simplify configuration by providing defaults,
311318
# as explained here: https://gateway.dask.org/configuration-user.html
312319
self._add_configmap_volume_and_binding(
313-
name=self.daskGateway_cm,
314-
cm_name=self.daskGateway_cm_name,
320+
name=self.dask_cm_name,
321+
cm_name=self.dask_cm_claim_name,
315322
target=self.daskGateway_config_dir)
316323

317324

@@ -375,7 +382,7 @@ def get_pod_name(pod):
375382

376383
pod = self.create_kubernetes_runtime(runtimeContext) # analogous to create_runtime()
377384
self.execute_kubernetes_pod(pod) # analogous to _execute()
378-
completion_result = self.wait_for_kubernetes_pod()
385+
completion_result = self.wait_for_kubernetes_pod(cm_name = self.dask_cm_name)
379386
if completion_result.exit_code != 0:
380387
log_main.error(f"ERROR the command below failed in pod {get_pod_name(pod)}:")
381388
log_main.error("\t" + " ".join(get_pod_command(pod)))
@@ -387,6 +394,15 @@ class KubernetesDaskClient(KubernetesClient):
387394
def __init__(self):
388395
super().__init__()
389396

397+
@retry_exponential_if_exception_type((ApiException, HTTPError,), log)
def submit_pod(self, pod_body):
    """Create a pod from ``pod_body`` and register it with the pod monitor.

    The pod is created in ``self.namespace`` through ``self.core_api_instance``.
    While the ``DaskPodMonitor`` context is held, the created pod is logged,
    added to the monitor and stored on this client via ``_set_pod``.

    The method is wrapped by ``retry_exponential_if_exception_type`` for
    ``ApiException`` and ``HTTPError``.

    :param pod_body: the pod specification handed to ``create_namespaced_pod``.
    """
    with DaskPodMonitor() as pod_monitor:
        created_pod = self.core_api_instance.create_namespaced_pod(self.namespace, pod_body)
        pod_name = created_pod.metadata.name
        pod_uid = created_pod.metadata.uid
        log.info('Created k8s pod name {} with id {}'.format(pod_name, pod_uid))
        # Register with the monitor before recording the pod on the client.
        pod_monitor.add(created_pod)
        self._set_pod(created_pod)
404+
405+
390406
@retry_exponential_if_exception_type((ApiException, HTTPError,), log)
391407
def follow_logs(self, status):
392408
pod_name = self.pod.metadata.name
@@ -409,7 +425,7 @@ def follow_logs(self, status):
409425

410426

411427
@retry_exponential_if_exception_type((ApiException, HTTPError, IncompleteStatusException), log)
412-
def wait_for_completion(self) -> CompletionResult:
428+
def wait_for_completion(self, cm_name: str) -> CompletionResult:
413429
w = watch.Watch()
414430
for event in w.stream(self.core_api_instance.list_namespaced_pod, self.namespace, field_selector=self._get_pod_field_selector()):
415431
pod = event['object']
@@ -439,7 +455,7 @@ def wait_for_completion(self) -> CompletionResult:
439455
if self.should_delete_pod():
440456
with DaskPodMonitor() as monitor:
441457
self.delete_pod_name(pod.metadata.name)
442-
self.delete_configmap_name(cm_name="dask-gateway-cm")
458+
self.delete_configmap_name(cm_name=cm_name)
443459
monitor.remove(pod)
444460
self._clear_pod()
445461
# stop watching for events, our pod is done. Causes wait loop to exit
@@ -469,11 +485,11 @@ def get_last_or_none(container_list: List[Union[V1ContainerStatus, V1Container]]
469485
return container_list[-1]
470486

471487
@retry_exponential_if_exception_type((ApiException, HTTPError,), log)
472-
def create_dask_gateway_cofig_map(self, dask_gateway_url: str):
488+
def create_dask_gateway_cofig_map(self, dask_gateway_url: str, cm_name: str):
473489
gateway = {'gateway': {'address': dask_gateway_url}}
474490

475491
configmap = client.V1ConfigMap(
476-
metadata=client.V1ObjectMeta(name="dask-gateway-cm"),
492+
metadata=client.V1ObjectMeta(name=cm_name),
477493
data={
478494
"gateway.yaml": yaml.dump(gateway)
479495
}
@@ -518,7 +534,7 @@ def delete_configmap_name(self, cm_name):
518534
class DaskPodMonitor(PodMonitor):
519535
def __init__(self):
520536
super().__init__()
521-
537+
522538
@staticmethod
523539
def cleanup():
524540
log.info('Starting Cleanup')
@@ -528,7 +544,6 @@ def cleanup():
528544
log.info('PodMonitor deleting pod {}'.format(pod_name))
529545
try:
530546
k8s_client.delete_pod_name(pod_name)
531-
k8s_client.delete_configmap_name(cm_name="dask-gateway-cm")
532547
except Exception:
533548
log.error('Error deleting pod named {}, ignoring'.format(pod_name))
534549
PodMonitor.pod_names = []

Diff for: calrissian/main.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from calrissian.version import version
99
from calrissian.k8s import delete_pods
1010
from calrissian.report import initialize_reporter, write_report, CPUParser, MemoryParser
11+
from calrissian.dask import DaskPodMonitor
1112
from cwltool.main import main as cwlmain
1213
from cwltool.argparser import arg_parser
1314
from typing_extensions import Text
@@ -101,7 +102,6 @@ def install_signal_handler():
101102
"""
102103
signal.signal(signal.SIGTERM, handle_sigterm)
103104

104-
105105
def install_tees(stdout_path=None, stderr_path=None):
106106
"""
107107
Reconnects stdout/stderr to `tee` processes via subprocess.PIPE that can write to user-supplied files
@@ -174,7 +174,10 @@ def main():
174174
)
175175
finally:
176176
# Always clean up after cwlmain
177-
delete_pods()
177+
if parsed_args.dask_gateway_url:
178+
DaskPodMonitor.cleanup()
179+
else:
180+
delete_pods()
178181
if parsed_args.usage_report:
179182
write_report(parsed_args.usage_report)
180183
flush_tees()

Diff for: pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ RETRY_ATTEMPTS="0"
122122
[tool.hatch.envs.test.scripts]
123123
test = "hatch run nose2"
124124
testv = "hatch run nose2 --verbose"
125+
testdask = "hatch run nose2 tests.test_dask"
125126
cov = ["coverage run --source=calrissian -m nose2", "coverage report"]
126127

127128
[[tool.hatch.envs.test.matrix]]

Diff for: tests/test_dask.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import os
2+
3+
from unittest import TestCase
4+
from unittest.mock import patch, call, Mock, mock_open
5+
import logging
6+
7+
from cwltool.utils import CWLObjectType
8+
9+
from calrissian.dask import (
10+
KubernetesDaskPodBuilder
11+
)
12+
13+
from calrissian.dask import (
14+
dask_req_validate,
15+
)
16+
17+
18+
class ValidateExtensionTestCase(TestCase):
    """Checks ``dask_req_validate`` against a sample DaskGatewayRequirement."""

    def setUp(self):
        """Store the requirement dictionary used by the test."""
        requirement: CWLObjectType = {
            "workerCores": 2,
            "workerCoresLimit": 2,
            "workerMemory": "4G",
            "clustermaxCore": 8,
            "clusterMaxMemory": "16G",
            "class": "https://calrissian-cwl.github.io/schema#DaskGatewayRequirement",  # From cwl
        }
        self.daskGatewayRequirement = requirement

    def tests_validate_extension(self):
        """The sample requirement must be reported as valid (truthy)."""
        validation_result = dask_req_validate(self.daskGatewayRequirement)
        self.assertTrue(validation_result)
32+
33+
34+
class KubernetesDaskPodBuilderTestCase(TestCase):
    """Tests for the container arguments, init command and environment
    produced by ``KubernetesDaskPodBuilder``."""

    def setUp(self):
        """Create a pod builder from fixed inputs and attach a Dask requirement."""
        self.name = 'PodName'
        self.container_image = 'dockerimage:1.0'
        self.environment = {
            'HOME': '/homedir',
            'PYTHONPATH': '/app',
        }
        self.volume_mounts = [Mock(), Mock()]
        self.volumes = [Mock()]
        self.command_line = ['cat']
        self.stdout = 'stdout.txt'
        self.stderr = 'stderr.txt'
        self.stdin = 'stdin.txt'
        self.resources = {'cores': 1, 'ram': 1024}
        self.labels = {'key1': 'val1', 'key2': 123}
        self.nodeselectors = {'disktype': 'ssd', 'cachelevel': 2}
        self.security_context = {'runAsUser': os.getuid(), 'runAsGroup': os.getgid()}
        self.pod_serviceaccount = "podmanager"
        self.dask_gateway_url = "http://dask-gateway-url:80"
        self.dask_gateway_controller = False

        # Positional arguments, in the order the builder is called with.
        builder_args = (
            self.name,
            self.container_image,
            self.environment,
            self.volume_mounts,
            self.volumes,
            self.command_line,
            self.stdout,
            self.stderr,
            self.stdin,
            self.resources,
            self.labels,
            self.nodeselectors,
            self.security_context,
            self.pod_serviceaccount,
            self.dask_gateway_url,
            self.dask_gateway_controller,
        )
        self.pod_builder = KubernetesDaskPodBuilder(*builder_args)
        self.pod_builder.dask_requirement = {
            "workerCores": 2,
            "workerCoresLimit": 2,
            "workerMemory": "4G",
            "clustermaxCore": 8,
            "clusterMaxMemory": "16G",
            "class": "https://calrissian-cwl.github.io/schema#DaskGatewayRequirement",  # From cwl
        }

    def test_main_container_args_without_redirects(self):
        """With no stdout/stderr/stdin set, the command carries no redirections."""
        # container_args returns a list with a single item since it is passed to 'sh', '-c'
        for stream_attr in ('stdout', 'stderr', 'stdin'):
            setattr(self.pod_builder, stream_attr, None)
        expected_args = ['set -e; trap "touch /shared/completed" EXIT;export DASK_CLUSTER=$(cat /shared/dask_cluster_name.txt) ; cat']
        self.assertEqual(expected_args, self.pod_builder.container_args())

    def test_container_args_with_redirects(self):
        """With all three streams set, the command redirects each of them."""
        expected_args = ['set -e; trap "touch /shared/completed" EXIT;export DASK_CLUSTER=$(cat /shared/dask_cluster_name.txt) ; cat > stdout.txt 2> stderr.txt < stdin.txt']
        self.assertEqual(expected_args, self.pod_builder.container_args())

    def test_init_container_command_with_external_script(self):
        """With the controller flag set, the init command runs /controller/init-dask.py."""
        self.pod_builder.dask_gateway_controller = True
        actual_command = self.pod_builder.init_container_command()
        self.assertEqual(['python', '/controller/init-dask.py'], actual_command)

    @patch("builtins.open", new_callable=mock_open, read_data="print('Default script')")
    @patch("os.path.join", return_value="/mocked/path/init-dask.py")  # Mock path join
    def test_init_container_command_with_default_script(self, mock_path_join, mock_file):
        """With the controller flag unset, the patched file content is passed via ``python -c``."""
        self.pod_builder.dask_gateway_controller = False
        actual_command = self.pod_builder.init_container_command()
        self.assertEqual(['python', '-c', "print('Default script')"], actual_command)

    def test_container_environment(self):
        """The environment holds the user variables plus eight Dask-related entries."""
        env_entries = self.pod_builder.container_environment()
        # +8 because the Dask-related variables are added at runtime
        self.assertEqual(len(self.environment) + 8, len(env_entries))
        expected_entries = [
            {'name': 'HOME', 'value': '/homedir'},
            {'name': 'PYTHONPATH', 'value': '/app'},
            {'name': 'DASK_GATEWAY_WORKER_CORES', 'value': '2'},
            {'name': 'DASK_GATEWAY_WORKER_MEMORY', 'value': '4G'},
            {'name': 'DASK_GATEWAY_WORKER_CORES_LIMIT', 'value': '2'},
            {'name': 'DASK_GATEWAY_CLUSTER_MAX_CORES', 'value': '8'},
            {'name': 'DASK_GATEWAY_CLUSTER_MAX_RAM', 'value': '16G'},
            {'name': 'DASK_GATEWAY_URL', 'value': 'http://dask-gateway-url:80'},
            {'name': 'DASK_GATEWAY_IMAGE', 'value': 'dockerimage:1.0'},  # Replace with actual image if needed
            {'name': 'DASK_CLUSTER_NAME_PATH', 'value': '/shared/dask_cluster_name.txt'},
        ]
        for entry in expected_entries:
            self.assertIn(entry, env_entries)

0 commit comments

Comments
 (0)