Skip to content

Commit 625b209

Browse files
committed
feat: add custom volumes/volume mounts for ray clusters
1 parent e7790c5 commit 625b209

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
250250
"""
251251
pod_spec = V1PodSpec(
252252
containers=containers,
253-
volumes=VOLUMES,
253+
volumes=generate_custom_storage(cluster.config.volumes, VOLUMES),
254254
)
255255
if cluster.config.image_pull_secrets != []:
256256
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
@@ -296,7 +296,9 @@ def get_head_container_spec(
296296
cluster.config.head_memory_limits,
297297
cluster.config.head_extended_resource_requests,
298298
),
299-
volume_mounts=VOLUME_MOUNTS,
299+
volume_mounts=generate_custom_storage(
300+
cluster.config.volume_mounts, VOLUME_MOUNTS
301+
),
300302
)
301303
if cluster.config.envs != {}:
302304
head_container.env = generate_env_vars(cluster)
@@ -338,7 +340,9 @@ def get_worker_container_spec(
338340
cluster.config.worker_memory_limits,
339341
cluster.config.worker_extended_resource_requests,
340342
),
341-
volume_mounts=VOLUME_MOUNTS,
343+
volume_mounts=generate_custom_storage(
344+
cluster.config.volume_mounts, VOLUME_MOUNTS
345+
),
342346
)
343347

344348
if cluster.config.envs != {}:
@@ -522,6 +526,22 @@ def wrap_cluster(
522526

523527

524528
# Etc.
529+
def generate_custom_storage(provided_storage: list, default_storage: list):
530+
"""
531+
The generate_custom_storage function updates the volumes/volume mounts configs with the default volumes/volume mounts.
532+
"""
533+
storage_list = provided_storage.copy()
534+
535+
if storage_list == []:
536+
storage_list = default_storage
537+
else:
538+
# We append the list of volumes/volume mounts with the defaults and return the full list
539+
for storage in default_storage:
540+
storage_list.append(storage)
541+
542+
return storage_list
543+
544+
525545
def write_to_file(cluster: "codeflare_sdk.ray.cluster.Cluster", resource: dict):
526546
"""
527547
The write_to_file function writes the built Ray Cluster/AppWrapper dict as a yaml file in the .codeflare folder

src/codeflare_sdk/ray/cluster/config.py

+7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Volume, V1VolumeMount
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -89,6 +90,10 @@ class ClusterConfiguration:
8990
A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
9091
overwrite_default_resource_mapping:
9192
A boolean indicating whether to overwrite the default resource mapping.
93+
volumes:
94+
A list of V1Volume objects to add to the Cluster
95+
volume_mounts:
96+
A list of V1VolumeMount objects to add to the Cluster
9297
"""
9398

9499
name: str
@@ -126,6 +131,8 @@ class ClusterConfiguration:
126131
extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
127132
overwrite_default_resource_mapping: bool = False
128133
local_queue: Optional[str] = None
134+
volumes: list[V1Volume] = field(default_factory=list)
135+
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
129136

130137
def __post_init__(self):
131138
if not self.verify_tls:

0 commit comments

Comments
 (0)