16
16
from calrissian .retry import retry_exponential_if_exception_type
17
17
from calrissian .job import (
18
18
CalrissianCommandLineJob ,
19
- KubernetesPodBuilder ,
19
+ KubernetesPodBuilder
20
20
)
21
21
from calrissian .job import (
22
22
quoted_arg_list ,
23
- read_yaml
23
+ read_yaml ,
24
+ random_tag ,
25
+ k8s_safe_name ,
24
26
)
25
27
from calrissian .job import (
26
28
DEFAULT_INIT_IMAGE ,
@@ -246,8 +248,6 @@ class CalrissianCommandLineDaskJob(CalrissianCommandLineJob):
246
248
daskGateway_controller_dir = '/controller'
247
249
248
250
daskGateway_config_dir = '/etc/dask'
249
- daskGateway_cm_name = 'dask-gateway-cm'
250
- daskGateway_cm = 'dask-gateway-cm'
251
251
252
252
daskGateway_controller_cm_name = 'dask-cluster-controller-cm'
253
253
@@ -256,8 +256,14 @@ def __init__(self, *args, **kwargs):
256
256
super (CalrissianCommandLineDaskJob , self ).__init__ (* args , ** kwargs )
257
257
self .client = KubernetesDaskClient ()
258
258
259
- def wait_for_kubernetes_pod (self ):
260
- return self .client .wait_for_completion ()
259
+ self .dask_cm_name , self .dask_cm_claim_name = self .dask_configmap_name ()
260
+
261
+ def dask_configmap_name (self ):
262
+ tag = random_tag ()
263
+ return k8s_safe_name ('{}-cm-{}' .format ('dask' , tag )), k8s_safe_name ('{}-cm-{}' .format ('dask' , tag ))
264
+
265
+ def wait_for_kubernetes_pod (self , cm_name : str ):
266
+ return self .client .wait_for_completion (cm_name = cm_name )
261
267
262
268
def get_dask_gateway_url (self , runtimeContext ):
263
269
return runtimeContext .dask_gateway_url
@@ -301,7 +307,8 @@ def create_kubernetes_runtime(self, runtimeContext):
301
307
302
308
303
309
self .client .create_dask_gateway_cofig_map (
304
- dask_gateway_url = self .get_dask_gateway_url (runtimeContext ))
310
+ dask_gateway_url = self .get_dask_gateway_url (runtimeContext ),
311
+ cm_name = self .dask_cm_name )
305
312
306
313
# emptyDir volume at /shared for sharing the Dask cluster name between containers
307
314
self ._add_emptydir_volume_and_binding ('shared-data' , self .container_shared_dir )
@@ -310,8 +317,8 @@ def create_kubernetes_runtime(self, runtimeContext):
310
317
# Need this ConfigMap to simplify configuration by providing defaults,
311
318
# as explained here: https://gateway.dask.org/configuration-user.html
312
319
self ._add_configmap_volume_and_binding (
313
- name = self .daskGateway_cm ,
314
- cm_name = self .daskGateway_cm_name ,
320
+ name = self .dask_cm_name ,
321
+ cm_name = self .dask_cm_claim_name ,
315
322
target = self .daskGateway_config_dir )
316
323
317
324
@@ -375,7 +382,7 @@ def get_pod_name(pod):
375
382
376
383
pod = self .create_kubernetes_runtime (runtimeContext ) # analogous to create_runtime()
377
384
self .execute_kubernetes_pod (pod ) # analogous to _execute()
378
- completion_result = self .wait_for_kubernetes_pod ()
385
+ completion_result = self .wait_for_kubernetes_pod (cm_name = self . dask_cm_name )
379
386
if completion_result .exit_code != 0 :
380
387
log_main .error (f"ERROR the command below failed in pod { get_pod_name (pod )} :" )
381
388
log_main .error ("\t " + " " .join (get_pod_command (pod )))
@@ -387,6 +394,15 @@ class KubernetesDaskClient(KubernetesClient):
387
394
def __init__ (self ):
388
395
super ().__init__ ()
389
396
397
+ @retry_exponential_if_exception_type ((ApiException , HTTPError ,), log )
398
+ def submit_pod (self , pod_body ):
399
+ with DaskPodMonitor () as monitor :
400
+ pod = self .core_api_instance .create_namespaced_pod (self .namespace , pod_body )
401
+ log .info ('Created k8s pod name {} with id {}' .format (pod .metadata .name , pod .metadata .uid ))
402
+ monitor .add (pod )
403
+ self ._set_pod (pod )
404
+
405
+
390
406
@retry_exponential_if_exception_type ((ApiException , HTTPError ,), log )
391
407
def follow_logs (self , status ):
392
408
pod_name = self .pod .metadata .name
@@ -409,7 +425,7 @@ def follow_logs(self, status):
409
425
410
426
411
427
@retry_exponential_if_exception_type ((ApiException , HTTPError , IncompleteStatusException ), log )
412
- def wait_for_completion (self ) -> CompletionResult :
428
+ def wait_for_completion (self , cm_name : str ) -> CompletionResult :
413
429
w = watch .Watch ()
414
430
for event in w .stream (self .core_api_instance .list_namespaced_pod , self .namespace , field_selector = self ._get_pod_field_selector ()):
415
431
pod = event ['object' ]
@@ -439,7 +455,7 @@ def wait_for_completion(self) -> CompletionResult:
439
455
if self .should_delete_pod ():
440
456
with DaskPodMonitor () as monitor :
441
457
self .delete_pod_name (pod .metadata .name )
442
- self .delete_configmap_name (cm_name = "dask-gateway-cm" )
458
+ self .delete_configmap_name (cm_name = cm_name )
443
459
monitor .remove (pod )
444
460
self ._clear_pod ()
445
461
# stop watching for events, our pod is done. Causes wait loop to exit
@@ -469,11 +485,11 @@ def get_last_or_none(container_list: List[Union[V1ContainerStatus, V1Container]]
469
485
return container_list [- 1 ]
470
486
471
487
@retry_exponential_if_exception_type ((ApiException , HTTPError ,), log )
472
- def create_dask_gateway_cofig_map (self , dask_gateway_url : str ):
488
+ def create_dask_gateway_cofig_map (self , dask_gateway_url : str , cm_name : str ):
473
489
gateway = {'gateway' : {'address' : dask_gateway_url }}
474
490
475
491
configmap = client .V1ConfigMap (
476
- metadata = client .V1ObjectMeta (name = "dask-gateway-cm" ),
492
+ metadata = client .V1ObjectMeta (name = cm_name ),
477
493
data = {
478
494
"gateway.yaml" : yaml .dump (gateway )
479
495
}
@@ -518,7 +534,7 @@ def delete_configmap_name(self, cm_name):
518
534
class DaskPodMonitor (PodMonitor ):
519
535
def __init__ (self ):
520
536
super ().__init__ ()
521
-
537
+
522
538
@staticmethod
523
539
def cleanup ():
524
540
log .info ('Starting Cleanup' )
@@ -528,7 +544,6 @@ def cleanup():
528
544
log .info ('PodMonitor deleting pod {}' .format (pod_name ))
529
545
try :
530
546
k8s_client .delete_pod_name (pod_name )
531
- k8s_client .delete_configmap_name (cm_name = "dask-gateway-cm" )
532
547
except Exception :
533
548
log .error ('Error deleting pod named {}, ignoring' .format (pod_name ))
534
549
PodMonitor .pod_names = []
0 commit comments