Skip to content

Commit 93c4c12

Browse files
committed
feat: add custom volumes/volume mounts for ray clusters
1 parent 6b0a3cc commit 93c4c12

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

Diff for: src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
249249
"""
250250
pod_spec = V1PodSpec(
251251
containers=containers,
252-
volumes=VOLUMES,
252+
volumes=generate_custom_storage(cluster.config.volumes, VOLUMES),
253253
)
254254
if cluster.config.image_pull_secrets != []:
255255
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
@@ -295,7 +295,9 @@ def get_head_container_spec(
295295
cluster.config.head_memory_limits,
296296
cluster.config.head_extended_resource_requests,
297297
),
298-
volume_mounts=VOLUME_MOUNTS,
298+
volume_mounts=generate_custom_storage(
299+
cluster.config.volume_mounts, VOLUME_MOUNTS
300+
),
299301
)
300302
if cluster.config.envs != {}:
301303
head_container.env = generate_env_vars(cluster)
@@ -337,7 +339,9 @@ def get_worker_container_spec(
337339
cluster.config.worker_memory_limits,
338340
cluster.config.worker_extended_resource_requests,
339341
),
340-
volume_mounts=VOLUME_MOUNTS,
342+
volume_mounts=generate_custom_storage(
343+
cluster.config.volume_mounts, VOLUME_MOUNTS
344+
),
341345
)
342346

343347
if cluster.config.envs != {}:
@@ -521,6 +525,22 @@ def wrap_cluster(
521525

522526

523527
# Etc.
528+
def generate_custom_storage(provided_storage: list, default_storage: list):
529+
"""
530+
The generate_custom_storage function updates the volumes/volume mounts configs with the default volumes/volume mounts.
531+
"""
532+
storage_list = provided_storage.copy()
533+
534+
if storage_list == []:
535+
storage_list = default_storage
536+
else:
537+
# We append the list of volumes/volume mounts with the defaults and return the full list
538+
for storage in default_storage:
539+
storage_list.append(storage)
540+
541+
return storage_list
542+
543+
524544
def write_to_file(cluster: "codeflare_sdk.ray.cluster.Cluster", resource: dict):
525545
"""
526546
The write_to_file function writes the built Ray Cluster/AppWrapper dict as a yaml file in the .codeflare folder

Diff for: src/codeflare_sdk/ray/cluster/config.py

+10
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Volume, V1VolumeMount
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -89,8 +90,15 @@ class ClusterConfiguration:
8990
A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
9091
overwrite_default_resource_mapping:
9192
A boolean indicating whether to overwrite the default resource mapping.
93+
<<<<<<< HEAD
9294
annotations:
9395
A dictionary of annotations to apply to the cluster.
96+
=======
97+
volumes:
98+
A list of V1Volume objects to add to the Cluster
99+
volume_mounts:
100+
A list of V1VolumeMount objects to add to the Cluster
101+
>>>>>>> 625b209 (feat: add custom volumes/volume mounts for ray clusters)
94102
"""
95103

96104
name: str
@@ -129,6 +137,8 @@ class ClusterConfiguration:
129137
overwrite_default_resource_mapping: bool = False
130138
local_queue: Optional[str] = None
131139
annotations: Dict[str, str] = field(default_factory=dict)
140+
volumes: list[V1Volume] = field(default_factory=list)
141+
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
132142

133143
def __post_init__(self):
134144
if not self.verify_tls:

0 commit comments

Comments
 (0)