Skip to content

Commit c68d425

Browse files
Added custom Volumes and Volume Mounts support
1 parent 3d9ebc9 commit c68d425

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

src/codeflare_sdk/cluster/cluster.py

+4
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def create_app_wrapper(self):
152152
write_to_file = self.config.write_to_file
153153
local_queue = self.config.local_queue
154154
labels = self.config.labels
155+
volumes = self.config.volumes
156+
volume_mounts = self.config.volume_mounts
155157
return generate_appwrapper(
156158
name=name,
157159
namespace=namespace,
@@ -172,6 +174,8 @@ def create_app_wrapper(self):
172174
write_to_file=write_to_file,
173175
local_queue=local_queue,
174176
labels=labels,
177+
volumes=volumes,
178+
volume_mounts=volume_mounts,
175179
)
176180

177181
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class ClusterConfiguration:
5353
write_to_file: bool = False
5454
verify_tls: bool = True
5555
labels: dict = field(default_factory=dict)
56+
volumes: list = field(default_factory=list)
57+
volume_mounts: list = field(default_factory=list)
5658

5759
def __post_init__(self):
5860
if not self.verify_tls:

src/codeflare_sdk/utils/generate_yaml.py

+22
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,20 @@ def update_image_pull_secrets(spec, image_pull_secrets):
9696
]
9797

9898

99+
def update_volume_mounts(spec, volume_mounts: list):
100+
containers = spec.get("containers")
101+
for volume_mount in volume_mounts:
102+
for container in containers:
103+
volumeMount = client.ApiClient().sanitize_for_serialization(volume_mount)
104+
container["volumeMounts"].append(volumeMount)
105+
106+
107+
def update_volumes(spec, volumes: list):
108+
for volume in volumes:
109+
new_volume = client.ApiClient().sanitize_for_serialization(volume)
110+
spec["volumes"].append(new_volume)
111+
112+
99113
def update_env(spec, env):
100114
containers = spec.get("containers")
101115
for container in containers:
@@ -136,6 +150,8 @@ def update_nodes(
136150
head_cpus,
137151
head_memory,
138152
head_gpus,
153+
volumes,
154+
volume_mounts,
139155
):
140156
head = cluster_yaml.get("spec").get("headGroupSpec")
141157
head["rayStartParams"]["num-gpus"] = str(int(head_gpus))
@@ -150,6 +166,8 @@ def update_nodes(
150166

151167
for comp in [head, worker]:
152168
spec = comp.get("template").get("spec")
169+
update_volume_mounts(spec, volume_mounts)
170+
update_volumes(spec, volumes)
153171
update_image_pull_secrets(spec, image_pull_secrets)
154172
update_image(spec, image)
155173
update_env(spec, env)
@@ -280,6 +298,8 @@ def generate_appwrapper(
280298
write_to_file: bool,
281299
local_queue: Optional[str],
282300
labels,
301+
volumes: list[client.V1Volume],
302+
volume_mounts: list[client.V1VolumeMount],
283303
):
284304
cluster_yaml = read_template(template)
285305
appwrapper_name, cluster_name = gen_names(name)
@@ -299,6 +319,8 @@ def generate_appwrapper(
299319
head_cpus,
300320
head_memory,
301321
head_gpus,
322+
volumes,
323+
volume_mounts,
302324
)
303325
augment_labels(cluster_yaml, labels)
304326
notebook_annotations(cluster_yaml)

0 commit comments

Comments
 (0)