Skip to content

Commit a27849c

Browse files
committed
Code refactor - add default init and dispose scripts - new optional ConfigMap with custom init and dispose schema
1 parent a1f4c24 commit a27849c

File tree

8 files changed

+389
-154
lines changed

8 files changed

+389
-154
lines changed

Diff for: calrissian/context.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ def __init__(self, kwargs=None):
2323
self.pod_serviceaccount = None
2424
self.tool_logs_basepath = None
2525
self.max_gpus = None
26-
self.gateway_url = None
26+
self.dask_gateway_url = None
27+
self.dask_gateway_extra_config = None
2728
return super(CalrissianRuntimeContext, self).__init__(kwargs)

Diff for: calrissian/dask.py

+263-61
Large diffs are not rendered by default.
File renamed without changes.

Diff for: calrissian/dask/dispose-dask.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Sidecar script that shuts down a Dask Gateway cluster.

Waits until the main container signals completion by creating a marker
file, then reads back the cluster name written by the init script and
asks the gateway to shut the cluster down.
"""
import logging
import os
import sys
import time

from dask_gateway import Gateway

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Path of the file the init script wrote the cluster name to.
source = os.environ.get("DASK_CLUSTER_NAME_PATH", None)
# Address of the Dask Gateway server.
gateway_url = os.environ.get("DASK_GATEWAY_URL", None)
# Marker file created by the main container when its work is done.
signal = "/shared/completed"

# Fail fast with a clear message instead of a TypeError on open(None)
# or Gateway(None) further down.
if source is None or gateway_url is None:
    logger.error("DASK_CLUSTER_NAME_PATH and DASK_GATEWAY_URL must both be set")
    sys.exit(1)

logger.info(f"Sidecar: Waiting for completion signal ({signal}) from main container...")

# Poll for the existence of the completion file.
while not os.path.exists(signal):
    logger.info("Sidecar: Waiting for completion signal from main container...")
    time.sleep(5)

logger.info("Sidecar: Completion signal received. Shutting down Dask cluster...")

# Read the cluster name persisted by the init script, then shut the
# Dask cluster down via the gateway.
with open(source, "r") as f:
    cluster_name = f.read().strip()

gateway = Gateway(gateway_url)
cluster = gateway.connect(cluster_name)
logger.info(f"Sidecar: Connected to Dask cluster: {cluster_name}")
cluster.shutdown()
logger.info("Sidecar: Dask cluster shut down successfully.")
sys.exit(0)

Diff for: calrissian/dask/init-dask.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# this code is responsible for creating a Dask cluster
2+
# it's executed by the CWL runner in the context of the Dask Gateway extension
3+
# this is for the prototyping purposes only
4+
import os
5+
import re
6+
import logging
7+
from dask_gateway import Gateway
8+
9+
logger = logging.getLogger(__name__)
10+
logging.basicConfig(level=logging.INFO)
11+
12+
def parse_memory(mem_str):
    """Convert a human-readable memory string to a number of bytes.

    Decimal suffixes ("KB", "MB", "GB", "TB", "PB") use powers of 1000;
    single-letter suffixes ("K", "M", "G", "T", "P") use powers of 1024,
    matching the original table. Accepts fractional values ("1.5G"),
    lowercase units and surrounding whitespace as a backward-compatible
    generalization.

    :param mem_str: memory size such as "2GB" or "16G"
    :returns: size in bytes as an int
    :raises ValueError: if the string cannot be parsed or the unit is unknown
    """
    units = {
        "B": 1,
        "KB": 1000,
        "MB": 1000**2,
        "GB": 1000**3,
        "TB": 1000**4,
        "PB": 1000**5,
        "K": 1024,
        "M": 1024**2,
        "G": 1024**3,
        "T": 1024**4,
        "P": 1024**5,
    }

    # Anchored at both ends so trailing garbage ("2GBx") is rejected;
    # allows an optional fractional part and whitespace around tokens.
    match = re.match(r"^\s*(\d+(?:\.\d+)?)\s*([A-Za-z]+)\s*$", str(mem_str))
    if not match:
        raise ValueError(f"Invalid memory format: {mem_str}")

    value, unit = match.groups()
    unit = unit.upper()

    if unit not in units:
        raise ValueError(f"Unknown unit: {unit}")

    return int(float(value) * units[unit])
37+
38+
# Cluster-creation settings, all taken from the environment. The worker
# sizing values feed the gateway cluster options; the max_* values bound
# the number of workers the cluster is scaled to.
target = os.environ.get("DASK_CLUSTER_NAME_PATH", None)
gateway_url = os.environ.get("DASK_GATEWAY_URL", None)
image = os.environ.get("DASK_GATEWAY_IMAGE", None)
worker_cores = os.environ.get("DASK_GATEWAY_WORKER_CORES", None)
worker_cores_limit = os.environ.get("DASK_GATEWAY_WORKER_CORES_LIMIT", None)
worker_memory = os.environ.get("DASK_GATEWAY_WORKER_MEMORY", None)
max_cores = os.environ.get("DASK_GATEWAY_CLUSTER_MAX_CORES", None)
max_ram = os.environ.get("DASK_GATEWAY_CLUSTER_MAX_RAM", None)

# Fail fast with an explicit message instead of an opaque TypeError from
# float(None) / open(None) further down when configuration is missing.
_required = {
    "DASK_CLUSTER_NAME_PATH": target,
    "DASK_GATEWAY_URL": gateway_url,
    "DASK_GATEWAY_IMAGE": image,
    "DASK_GATEWAY_WORKER_CORES": worker_cores,
    "DASK_GATEWAY_WORKER_CORES_LIMIT": worker_cores_limit,
    "DASK_GATEWAY_WORKER_MEMORY": worker_memory,
    "DASK_GATEWAY_CLUSTER_MAX_CORES": max_cores,
    "DASK_GATEWAY_CLUSTER_MAX_RAM": max_ram,
}
_missing = [name for name, value in _required.items() if value is None]
if _missing:
    raise RuntimeError(
        "Missing required environment variables: {}".format(", ".join(_missing))
    )

logger.info(f"Creating Dask cluster and saving the name to {target}")

gateway = Gateway(gateway_url)

cluster_options = gateway.cluster_options()

cluster_options['image'] = image
cluster_options['worker_cores'] = float(worker_cores)
cluster_options['worker_cores_limit'] = int(worker_cores_limit)
cluster_options['worker_memory'] = worker_memory

logger.info(f"Cluster options: {cluster_options}")
# shutdown_on_close=False: the cluster must outlive this init process;
# the dispose sidecar shuts it down after the main container completes.
cluster = gateway.new_cluster(cluster_options, shutdown_on_close=False)

logger.info(f"Resource requirements: {worker_cores} cores, {worker_memory}")
logger.info(f"Resource limits: {max_cores} cores, {max_ram} GB RAM")

# Worker count is bounded by both the CPU budget and the memory budget:
# take the smaller of the two.
workers = int(
    min(
        int(max_cores) // int(worker_cores_limit),
        parse_memory(max_ram) // parse_memory(worker_memory),
    )
)

logger.info(f"Scaling cluster to {workers} workers")
cluster.scale(workers)

# Persist the cluster name so the dispose sidecar can reconnect to it.
with open(target, "w") as f:
    f.write(cluster.name)
logger.info(f"Cluster name {cluster.name} saved to {target}")

Diff for: calrissian/job.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
from typing import Dict
23
from cwltool.job import ContainerCommandLineJob, needs_shell_quoting_re
34

Diff for: calrissian/k8s.py

+1-88
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from kubernetes.client.models import V1ContainerState, V1Container, V1ContainerStatus
44
from kubernetes.client.rest import ApiException
55
from kubernetes.config.config_exception import ConfigException
6+
import yaml
67
from calrissian.executor import IncompleteStatusException
78
from calrissian.retry import retry_exponential_if_exception_type
89
import threading
@@ -156,70 +157,6 @@ def follow_logs(self):
156157

157158
log.info('[{}] follow_logs end'.format(pod_name))
158159

159-
@retry_exponential_if_exception_type((ApiException, HTTPError,), log)
160-
def follow_container_logs(self, status):
161-
pod_name = self.pod.metadata.name
162-
163-
log.info('[{}] follow_logs start'.format(pod_name))
164-
for line in self.core_api_instance.read_namespaced_pod_log(self.pod.metadata.name, self.namespace, follow=True,
165-
_preload_content=False, container=status.name).stream():
166-
# .stream() is only available if _preload_content=False
167-
# .stream() returns a generator, each iteration yields bytes.
168-
# kubernetes-client decodes them as utf-8 when _preload_content is True
169-
# https://github.com/kubernetes-client/python/blob/fcda6fe96beb21cd05522c17f7f08c5a7c0e3dc3/kubernetes/client/rest.py#L215-L216
170-
# So we do the same here
171-
if not status.state.running:
172-
break
173-
line = line.decode('utf-8', errors="ignore").rstrip()
174-
log.debug('[{}] {}'.format(pod_name, line))
175-
self.tool_log.append(self.format_log_entry(pod_name, line))
176-
177-
log.info('[{}] follow_logs end'.format(pod_name))
178-
179-
@retry_exponential_if_exception_type((ApiException, HTTPError, IncompleteStatusException), log)
180-
def wait_for_dask_completion(self) -> CompletionResult:
181-
w = watch.Watch()
182-
for event in w.stream(self.core_api_instance.list_namespaced_pod, self.namespace, field_selector=self._get_pod_field_selector()):
183-
pod = event['object']
184-
# status = self.get_first_or_none(pod.status.container_statuses)
185-
last_status = self.get_last_or_none(pod.status.container_statuses)
186-
if last_status == None or not self.state_is_terminated(last_status.state):
187-
statuses = self.get_list_or_none(pod.status.container_statuses)
188-
if statuses == None:
189-
continue
190-
for status in statuses:
191-
log.info('pod name {} with id {} has status {}'.format(pod.metadata.name, pod.metadata.uid, status))
192-
if status is None:
193-
continue
194-
if self.state_is_waiting(status.state):
195-
continue
196-
elif self.state_is_running(status.state):
197-
# Can only get logs once container is running
198-
self.follow_container_logs(status) # This will not return until container completes
199-
elif self.state_is_terminated(status.state):
200-
continue
201-
else:
202-
raise CalrissianJobException('Unexpected pod container status', status)
203-
elif self.state_is_terminated(last_status.state):
204-
log.info('Handling terminated pod name {} with id {}'.format(pod.metadata.name, pod.metadata.uid))
205-
container = self.get_last_or_none(pod.spec.containers)
206-
self._handle_completion(last_status.state, container)
207-
if self.should_delete_pod():
208-
with PodMonitor() as monitor:
209-
self.delete_pod_name(pod.metadata.name)
210-
monitor.remove(pod)
211-
self._clear_pod()
212-
# stop watching for events, our pod is done. Causes wait loop to exit
213-
w.stop()
214-
else:
215-
raise CalrissianJobException('Unexpected pod container status', last_status)
216-
217-
# When the pod is done we should have a completion result
218-
# Otherwise it will lead to further exceptions
219-
if self.completion_result is None:
220-
raise IncompleteStatusException
221-
222-
return self.completion_result
223160

224161
@retry_exponential_if_exception_type((ApiException, HTTPError, IncompleteStatusException), log)
225162
def wait_for_completion(self) -> CompletionResult:
@@ -280,20 +217,6 @@ def state_is_waiting(state):
280217
def state_is_terminated(state):
281218
return state.terminated
282219

283-
@staticmethod
284-
def get_list_or_none(container_list: List[Union[V1ContainerStatus, V1Container]]) -> Union[V1ContainerStatus, V1Container]:
285-
if not container_list: # None or empty list
286-
return None
287-
else:
288-
return list(container_list)
289-
290-
@staticmethod
291-
def get_last_or_none(container_list: List[Union[V1ContainerStatus, V1Container]]) -> Union[V1ContainerStatus, V1Container]:
292-
if not container_list: # None or empty list
293-
return None
294-
else:
295-
return container_list[-1]
296-
297220
@staticmethod
298221
def get_first_or_none(container_list: List[Union[V1ContainerStatus, V1Container]]) -> Union[V1ContainerStatus, V1Container]:
299222
"""
@@ -350,16 +273,6 @@ def get_current_pod(self):
350273
raise CalrissianJobException("Missing required environment variable ${}".format(POD_NAME_ENV_VARIABLE))
351274
return self.get_pod_for_name(pod_name)
352275

353-
@retry_exponential_if_exception_type((ApiException, HTTPError,), log)
354-
def create_dask_gateway_cofig_map(self, gateway_url: str):
355-
configmap = client.V1ConfigMap(
356-
metadata=client.V1ObjectMeta(name="dask-gateway-cm"),
357-
data={
358-
"gateway.yaml": f"|\ngateway:\n\taddress:{gateway_url}"
359-
}
360-
)
361-
362-
self.core_api_instance.create_namespaced_config_map(namespace=self.namespace, body=configmap)
363276

364277
class PodMonitor(object):
365278
"""

Diff for: calrissian/main.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_log_level(parsed_args):
3434

3535

3636
def activate_logging(level):
37-
loggers = ['executor','context','tool','job', 'k8s','main']
37+
loggers = ['executor','context','tool','job', 'k8s','main', 'dask']
3838
for logger in loggers:
3939
logging.getLogger('calrissian.{}'.format(logger)).setLevel(level)
4040
logging.getLogger('calrissian.{}'.format(logger)).addHandler(logging.StreamHandler())
@@ -53,7 +53,8 @@ def add_arguments(parser):
5353
parser.add_argument('--stderr', type=Text, nargs='?', help='Output file name to tee standard error to (includes tool logs)')
5454
parser.add_argument('--tool-logs-basepath', type=Text, nargs='?', help='Base path for saving the tool logs')
5555
parser.add_argument('--conf', help='Defines the default values for the CLI arguments', action='append')
56-
parser.add_argument('--gateway-url', type=Text, nargs='?', help='Defines the Dask Gateway URL', required=False)
56+
parser.add_argument('--dask-gateway-url', type=Text, nargs='?', help='Defines the Dask Gateway URL', required=False)
57+
parser.add_argument('--dask-gateway-extra-config', type=Text, nargs='?', help='YAML file of extra k8s config for Dask', required=False)
5758

5859

5960
def print_version():
@@ -133,7 +134,7 @@ def add_custom_schema():
133134
cwltool.command_line_tool.ACCEPTLIST_RE = cwltool.command_line_tool.ACCEPTLIST_EN_RELAXED_RE
134135
supported_versions = ["v1.0", "v1.1", "v1.2"]
135136

136-
with open(os.path.join(os.path.dirname(__file__), "custom_schema/schema.yaml")) as f:
137+
with open(os.path.join(os.path.dirname(__file__), "dask/custom_schema/schema.yaml")) as f:
137138
schema_content = f.read()
138139

139140
for s in supported_versions:
@@ -160,13 +161,16 @@ def main():
160161
runtime_context = CalrissianRuntimeContext(vars(parsed_args))
161162
runtime_context.select_resources = executor.select_resources
162163
install_signal_handler()
164+
165+
parsed_args.enable_ext = True
166+
163167
try:
164168
result = cwlmain(args=parsed_args,
165169
executor=executor,
166170
loadingContext=CalrissianLoadingContext(),
167171
runtimeContext=runtime_context,
168172
versionfunc=version,
169-
custom_schema_callback=(add_custom_schema if parsed_args.gateway_url else None)
173+
custom_schema_callback=(add_custom_schema if parsed_args.dask_gateway_url else None)
170174
)
171175
finally:
172176
# Always clean up after cwlmain

0 commit comments

Comments
 (0)