@@ -96,6 +96,20 @@ def update_image_pull_secrets(spec, image_pull_secrets):
96
96
]
97
97
98
98
99
+ def update_volume_mounts (spec , volume_mounts : list ):
100
+ containers = spec .get ("containers" )
101
+ for volume_mount in volume_mounts :
102
+ for container in containers :
103
+ volumeMount = client .ApiClient ().sanitize_for_serialization (volume_mount )
104
+ container ["volumeMounts" ].append (volumeMount )
105
+
106
+
107
+ def update_volumes (spec , volumes : list ):
108
+ for volume in volumes :
109
+ new_volume = client .ApiClient ().sanitize_for_serialization (volume )
110
+ spec ["volumes" ].append (new_volume )
111
+
112
+
99
113
def update_env (spec , env ):
100
114
containers = spec .get ("containers" )
101
115
for container in containers :
@@ -136,6 +150,8 @@ def update_nodes(
136
150
head_cpus ,
137
151
head_memory ,
138
152
head_gpus ,
153
+ volumes ,
154
+ volume_mounts ,
139
155
):
140
156
head = cluster_yaml .get ("spec" ).get ("headGroupSpec" )
141
157
head ["rayStartParams" ]["num-gpus" ] = str (int (head_gpus ))
@@ -150,6 +166,8 @@ def update_nodes(
150
166
151
167
for comp in [head , worker ]:
152
168
spec = comp .get ("template" ).get ("spec" )
169
+ update_volume_mounts (spec , volume_mounts )
170
+ update_volumes (spec , volumes )
153
171
update_image_pull_secrets (spec , image_pull_secrets )
154
172
update_image (spec , image )
155
173
update_env (spec , env )
@@ -280,6 +298,8 @@ def generate_appwrapper(
280
298
write_to_file : bool ,
281
299
local_queue : Optional [str ],
282
300
labels ,
301
+ volumes : list [client .V1Volume ],
302
+ volume_mounts : list [client .V1VolumeMount ],
283
303
):
284
304
cluster_yaml = read_template (template )
285
305
appwrapper_name , cluster_name = gen_names (name )
@@ -299,6 +319,8 @@ def generate_appwrapper(
299
319
head_cpus ,
300
320
head_memory ,
301
321
head_gpus ,
322
+ volumes ,
323
+ volume_mounts ,
302
324
)
303
325
augment_labels (cluster_yaml , labels )
304
326
notebook_annotations (cluster_yaml )
0 commit comments