31
31
from ray .autoscaler ._private import util as ray_util
32
32
import rich .console as rich_console
33
33
import rich .progress as rich_progress
34
- import yaml
35
34
36
35
import sky
37
36
from sky import authentication as auth
@@ -120,8 +119,7 @@ def is_ip(s: str) -> bool:
120
119
def fill_template (template_name : str ,
121
120
variables : Dict ,
122
121
output_path : Optional [str ] = None ,
123
- output_prefix : str = SKY_USER_FILE_PATH ,
124
- dryrun : bool = False ) -> str :
122
+ output_prefix : str = SKY_USER_FILE_PATH ) -> str :
125
123
"""Create a file from a Jinja template and return the filename."""
126
124
assert template_name .endswith ('.j2' ), template_name
127
125
template_path = os .path .join (sky .__root_dir__ , 'templates' , template_name )
@@ -138,36 +136,44 @@ def fill_template(template_name: str,
138
136
output_path = str (output_path )
139
137
output_path = os .path .abspath (output_path )
140
138
141
- # Runtime files handling
142
- #
143
- # List of runtime files to be uploaded to cluster:
144
- # - yaml config (for autostopping)
145
- # - wheel
146
- # - credentials
147
- # Format is {dst: src}.
148
- file_mounts = {SKY_RAY_YAML_REMOTE_PATH : output_path }
149
-
150
- # fill_template() is also called to fill TPU/spot controller templates,
151
- # which don't have all variables.
152
- if 'sky_remote_path' in variables and 'sky_local_path' in variables :
153
- file_mounts [variables ['sky_remote_path' ]] = variables ['sky_local_path' ]
154
- if 'credentials' in variables :
155
- file_mounts .update (variables ['credentials' ])
139
+ # Add yaml file path to the template variables.
140
+ variables ['sky_ray_yaml_remote_path' ] = SKY_RAY_YAML_REMOTE_PATH
141
+ variables ['sky_ray_yaml_local_path' ] = output_path
142
+ # Write out yaml config.
143
+ template = jinja2 .Template (template )
144
+ content = template .render (** variables )
145
+ with open (output_path , 'w' ) as fout :
146
+ fout .write (content )
147
+ return output_path
148
+
149
+
150
+ def _optimize_file_mounts (yaml_path : str ) -> None :
151
+ """Optimize file mounts in the given ray yaml file.
152
+
153
+ Runtime files handling:
154
+ List of runtime files to be uploaded to cluster:
155
+ - yaml config (for autostopping)
156
+ - wheel
157
+ - credentials
158
+ Format is {dst: src}.
159
+ """
160
+ yaml_config = common_utils .read_yaml (yaml_path )
161
+
162
+ file_mounts = yaml_config .get ('file_mounts' , {})
163
+ # Remove the file mounts added by the newline.
164
+ if '' in file_mounts :
165
+ assert file_mounts ['' ] == '' , file_mounts ['' ]
166
+ file_mounts .pop ('' )
167
+
156
168
# Putting these in file_mounts hurts provisioning speed, as each file
157
169
# opens/closes an SSH connection. Instead, we:
158
170
# - cp locally them into a directory
159
171
# - upload that directory as a file mount (1 connection)
160
172
# - use a remote command to move all runtime files to their right places.
161
173
162
- # yaml config
163
- variables ['sky_ray_yaml_remote_path' ] = SKY_RAY_YAML_REMOTE_PATH
164
- variables ['sky_ray_yaml_local_path' ] = output_path
165
-
166
174
# Local tmp dir holding runtime files.
167
175
local_runtime_files_dir = tempfile .mkdtemp ()
168
- variables ['local_runtime_files_dir' ] = local_runtime_files_dir
169
- # Remote dir.
170
- variables ['remote_runtime_files_dir' ] = _REMOTE_RUNTIME_FILES_DIR
176
+ new_file_mounts = {_REMOTE_RUNTIME_FILES_DIR : local_runtime_files_dir }
171
177
172
178
# (For remote) Build a command that copies runtime files to their right
173
179
# destinations.
@@ -203,24 +209,28 @@ def fill_template(template_name: str,
203
209
f'{ dst_parent_dir } /{ dst_basename } ' )
204
210
fragment = f'({ mkdir_parent } && { mv } )'
205
211
commands .append (fragment )
206
- variables [ ' postprocess_runtime_files_command' ] = ' && ' .join (commands )
212
+ postprocess_runtime_files_command = ' && ' .join (commands )
207
213
208
- # Write out yaml config.
209
- template = jinja2 .Template (template )
210
- content = template .render (** variables )
211
- with open (output_path , 'w' ) as fout :
212
- fout .write (content )
214
+ setup_commands = yaml_config .get ('setup_commands' , [])
215
+ if setup_commands :
216
+ setup_commands [
217
+ 0 ] = f'{ postprocess_runtime_files_command } ; { setup_commands [0 ]} '
218
+ else :
219
+ setup_commands = [postprocess_runtime_files_command ]
220
+
221
+ yaml_config ['file_mounts' ] = new_file_mounts
222
+ yaml_config ['setup_commands' ] = setup_commands
213
223
214
224
# (For local) Move all runtime files, including the just-written yaml, to
215
225
# local_runtime_files_dir/.
216
- if not dryrun :
217
- all_local_sources = ' ' . join (
218
- local_src for local_src in file_mounts . values ())
219
- # Takes 10-20 ms on laptop incl. 3 clouds' credentials.
220
- subprocess_utils . run (
221
- f'cp -r { all_local_sources } { local_runtime_files_dir } /' )
226
+ all_local_sources = ' ' . join (
227
+ local_src for local_src in file_mounts . values ())
228
+ # Takes 10-20 ms on laptop incl. 3 clouds' credentials.
229
+ subprocess . run ( f'cp -r { all_local_sources } { local_runtime_files_dir } /' ,
230
+ shell = True ,
231
+ check = True )
222
232
223
- return output_path
233
+ common_utils . dump_yaml ( yaml_path , yaml_config )
224
234
225
235
226
236
def path_size_megabytes (path : str ) -> int :
@@ -726,13 +736,18 @@ def write_cluster_config(to_provision: 'resources.Resources',
726
736
'ssh_private_key' : (None if auth_config is None else
727
737
auth_config ['ssh_private_key' ]),
728
738
}),
729
- dryrun = dryrun ,
730
739
)
731
740
config_dict ['cluster_name' ] = cluster_name
732
741
config_dict ['ray' ] = yaml_path
733
742
if dryrun :
734
743
return config_dict
735
744
_add_auth_to_cluster_config (cloud , yaml_path )
745
+ # Delay the optimization of the config until the authentication files is added.
746
+ if not isinstance (cloud , clouds .Local ):
747
+ # Only optimize the file mounts for public clouds now, as local has not
748
+ # been fully tested yet.
749
+ _optimize_file_mounts (yaml_path )
750
+
736
751
usage_lib .messages .usage .update_ray_yaml (yaml_path )
737
752
# For TPU nodes. TPU VMs do not need TPU_NAME.
738
753
if (resources_vars .get ('tpu_type' ) is not None and
@@ -768,8 +783,7 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
768
783
769
784
This function's output removes comments included in the jinja2 template.
770
785
"""
771
- with open (cluster_config_file , 'r' ) as f :
772
- config = yaml .safe_load (f )
786
+ config = common_utils .read_yaml (cluster_config_file )
773
787
# Check the availability of the cloud type.
774
788
if isinstance (cloud , clouds .AWS ):
775
789
config = auth .setup_aws_authentication (config )
0 commit comments