Skip to content

Commit ca834c7

Browse files
feat: add custom volumes/volume mounts for ray clusters
1 parent 6b0a3cc commit ca834c7

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

Diff for: src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
249249
"""
250250
pod_spec = V1PodSpec(
251251
containers=containers,
252-
volumes=VOLUMES,
252+
volumes=generate_custom_storage(cluster.config.volumes, VOLUMES),
253253
)
254254
if cluster.config.image_pull_secrets != []:
255255
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
@@ -295,7 +295,9 @@ def get_head_container_spec(
295295
cluster.config.head_memory_limits,
296296
cluster.config.head_extended_resource_requests,
297297
),
298-
volume_mounts=VOLUME_MOUNTS,
298+
volume_mounts=generate_custom_storage(
299+
cluster.config.volume_mounts, VOLUME_MOUNTS
300+
),
299301
)
300302
if cluster.config.envs != {}:
301303
head_container.env = generate_env_vars(cluster)
@@ -337,7 +339,9 @@ def get_worker_container_spec(
337339
cluster.config.worker_memory_limits,
338340
cluster.config.worker_extended_resource_requests,
339341
),
340-
volume_mounts=VOLUME_MOUNTS,
342+
volume_mounts=generate_custom_storage(
343+
cluster.config.volume_mounts, VOLUME_MOUNTS
344+
),
341345
)
342346

343347
if cluster.config.envs != {}:
@@ -521,6 +525,22 @@ def wrap_cluster(
521525

522526

523527
# Etc.
528+
def generate_custom_storage(provided_storage: list, default_storage: list):
529+
"""
530+
The generate_custom_storage function updates the volumes/volume mounts configs with the default volumes/volume mounts.
531+
"""
532+
storage_list = provided_storage.copy()
533+
534+
if storage_list == []:
535+
storage_list = default_storage
536+
else:
537+
# We append the list of volumes/volume mounts with the defaults and return the full list
538+
for storage in default_storage:
539+
storage_list.append(storage)
540+
541+
return storage_list
542+
543+
524544
def write_to_file(cluster: "codeflare_sdk.ray.cluster.Cluster", resource: dict):
525545
"""
526546
The write_to_file function writes the built Ray Cluster/AppWrapper dict as a yaml file in the .codeflare folder

Diff for: src/codeflare_sdk/ray/cluster/config.py

+7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Volume, V1VolumeMount
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -91,6 +92,10 @@ class ClusterConfiguration:
9192
A boolean indicating whether to overwrite the default resource mapping.
9293
annotations:
9394
A dictionary of annotations to apply to the cluster.
95+
volumes:
96+
A list of V1Volume objects to add to the Cluster
97+
volume_mounts:
98+
A list of V1VolumeMount objects to add to the Cluster
9499
"""
95100

96101
name: str
@@ -129,6 +134,8 @@ class ClusterConfiguration:
129134
overwrite_default_resource_mapping: bool = False
130135
local_queue: Optional[str] = None
131136
annotations: Dict[str, str] = field(default_factory=dict)
137+
volumes: list[V1Volume] = field(default_factory=list)
138+
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
132139

133140
def __post_init__(self):
134141
if not self.verify_tls:

0 commit comments

Comments
 (0)