Skip to content

Commit fcfe289

Browse files
authored
Refactor the file_mounts optimization and oslogin (skypilot-org#1108)
* Refactor the file_mounts optimization * format * format * fix * change optimization to be for all clouds * Add ray yaml back * address comments
1 parent e8c0e1a commit fcfe289

File tree

4 files changed

+73
-49
lines changed

4 files changed

+73
-49
lines changed

sky/authentication.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
PRIVATE_SSH_KEY_PATH = '~/.ssh/sky-key'
3232

3333
GCP_CONFIGURE_PATH = '~/.config/gcloud/configurations/config_default'
34-
GCP_CONFIGURE_SKY_BACKUP_PATH = '~/.config/gcloud/configurations/.sky_config_default' # pylint: disable=line-too-long
34+
# Do not place the backup under the gcloud config directory, as ray
35+
# autoscaler can overwrite that directory on the remote nodes.
36+
GCP_CONFIGURE_SKY_BACKUP_PATH = '~/.sky/.sky_gcp_config_default'
3537

3638

3739
def generate_rsa_key_pair():
@@ -202,17 +204,17 @@ def setup_gcp_authentication(config):
202204
project_oslogin = next(
203205
(item for item in project['commonInstanceMetadata'].get('items', [])
204206
if item['key'] == 'enable-oslogin'), {}).get('value', 'False')
207+
205208
if project_oslogin.lower() == 'true':
206209
# project.
207210
logger.info(
208211
f'OS Login is enabled for GCP project {project_id}. Running '
209212
'additional authentication steps.')
213+
# Read the account information from the credential file, since the user
214+
# should be set according the account, when the oslogin is enabled.
210215
config_path = os.path.expanduser(GCP_CONFIGURE_PATH)
211216
sky_backup_config_path = os.path.expanduser(
212217
GCP_CONFIGURE_SKY_BACKUP_PATH)
213-
214-
# Read the account information from the credential file, since the user
215-
# should be set according the account, when the oslogin is enabled.
216218
if not os.path.exists(sky_backup_config_path):
217219
if not os.path.exists(config_path):
218220
with ux_utils.print_exception_no_traceback():
@@ -227,7 +229,10 @@ def setup_gcp_authentication(config):
227229
subprocess.run(f'cp {config_path} {sky_backup_config_path}',
228230
shell=True,
229231
check=True)
230-
232+
new_file_mounts = config.get('file_mounts', {})
233+
new_file_mounts[
234+
GCP_CONFIGURE_SKY_BACKUP_PATH] = GCP_CONFIGURE_SKY_BACKUP_PATH
235+
config['file_mounts'] = new_file_mounts
231236
with open(sky_backup_config_path, 'r') as infile:
232237
for line in infile:
233238
if line.startswith('account'):

sky/backends/backend_utils.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from ray.autoscaler._private import util as ray_util
3232
import rich.console as rich_console
3333
import rich.progress as rich_progress
34-
import yaml
3534

3635
import sky
3736
from sky import authentication as auth
@@ -120,8 +119,7 @@ def is_ip(s: str) -> bool:
120119
def fill_template(template_name: str,
121120
variables: Dict,
122121
output_path: Optional[str] = None,
123-
output_prefix: str = SKY_USER_FILE_PATH,
124-
dryrun: bool = False) -> str:
122+
output_prefix: str = SKY_USER_FILE_PATH) -> str:
125123
"""Create a file from a Jinja template and return the filename."""
126124
assert template_name.endswith('.j2'), template_name
127125
template_path = os.path.join(sky.__root_dir__, 'templates', template_name)
@@ -138,36 +136,44 @@ def fill_template(template_name: str,
138136
output_path = str(output_path)
139137
output_path = os.path.abspath(output_path)
140138

141-
# Runtime files handling
142-
#
143-
# List of runtime files to be uploaded to cluster:
144-
# - yaml config (for autostopping)
145-
# - wheel
146-
# - credentials
147-
# Format is {dst: src}.
148-
file_mounts = {SKY_RAY_YAML_REMOTE_PATH: output_path}
149-
150-
# fill_template() is also called to fill TPU/spot controller templates,
151-
# which don't have all variables.
152-
if 'sky_remote_path' in variables and 'sky_local_path' in variables:
153-
file_mounts[variables['sky_remote_path']] = variables['sky_local_path']
154-
if 'credentials' in variables:
155-
file_mounts.update(variables['credentials'])
139+
# Add yaml file path to the template variables.
140+
variables['sky_ray_yaml_remote_path'] = SKY_RAY_YAML_REMOTE_PATH
141+
variables['sky_ray_yaml_local_path'] = output_path
142+
# Write out yaml config.
143+
template = jinja2.Template(template)
144+
content = template.render(**variables)
145+
with open(output_path, 'w') as fout:
146+
fout.write(content)
147+
return output_path
148+
149+
150+
def _optimize_file_mounts(yaml_path: str) -> None:
151+
"""Optimize file mounts in the given ray yaml file.
152+
153+
Runtime files handling:
154+
List of runtime files to be uploaded to cluster:
155+
- yaml config (for autostopping)
156+
- wheel
157+
- credentials
158+
Format is {dst: src}.
159+
"""
160+
yaml_config = common_utils.read_yaml(yaml_path)
161+
162+
file_mounts = yaml_config.get('file_mounts', {})
163+
# Remove the file mounts added by the newline.
164+
if '' in file_mounts:
165+
assert file_mounts[''] == '', file_mounts['']
166+
file_mounts.pop('')
167+
156168
# Putting these in file_mounts hurts provisioning speed, as each file
157169
# opens/closes an SSH connection. Instead, we:
158170
# - cp locally them into a directory
159171
# - upload that directory as a file mount (1 connection)
160172
# - use a remote command to move all runtime files to their right places.
161173

162-
# yaml config
163-
variables['sky_ray_yaml_remote_path'] = SKY_RAY_YAML_REMOTE_PATH
164-
variables['sky_ray_yaml_local_path'] = output_path
165-
166174
# Local tmp dir holding runtime files.
167175
local_runtime_files_dir = tempfile.mkdtemp()
168-
variables['local_runtime_files_dir'] = local_runtime_files_dir
169-
# Remote dir.
170-
variables['remote_runtime_files_dir'] = _REMOTE_RUNTIME_FILES_DIR
176+
new_file_mounts = {_REMOTE_RUNTIME_FILES_DIR: local_runtime_files_dir}
171177

172178
# (For remote) Build a command that copies runtime files to their right
173179
# destinations.
@@ -203,24 +209,28 @@ def fill_template(template_name: str,
203209
f'{dst_parent_dir}/{dst_basename}')
204210
fragment = f'({mkdir_parent} && {mv})'
205211
commands.append(fragment)
206-
variables['postprocess_runtime_files_command'] = ' && '.join(commands)
212+
postprocess_runtime_files_command = ' && '.join(commands)
207213

208-
# Write out yaml config.
209-
template = jinja2.Template(template)
210-
content = template.render(**variables)
211-
with open(output_path, 'w') as fout:
212-
fout.write(content)
214+
setup_commands = yaml_config.get('setup_commands', [])
215+
if setup_commands:
216+
setup_commands[
217+
0] = f'{postprocess_runtime_files_command}; {setup_commands[0]}'
218+
else:
219+
setup_commands = [postprocess_runtime_files_command]
220+
221+
yaml_config['file_mounts'] = new_file_mounts
222+
yaml_config['setup_commands'] = setup_commands
213223

214224
# (For local) Move all runtime files, including the just-written yaml, to
215225
# local_runtime_files_dir/.
216-
if not dryrun:
217-
all_local_sources = ' '.join(
218-
local_src for local_src in file_mounts.values())
219-
# Takes 10-20 ms on laptop incl. 3 clouds' credentials.
220-
subprocess_utils.run(
221-
f'cp -r {all_local_sources} {local_runtime_files_dir}/')
226+
all_local_sources = ' '.join(
227+
local_src for local_src in file_mounts.values())
228+
# Takes 10-20 ms on laptop incl. 3 clouds' credentials.
229+
subprocess.run(f'cp -r {all_local_sources} {local_runtime_files_dir}/',
230+
shell=True,
231+
check=True)
222232

223-
return output_path
233+
common_utils.dump_yaml(yaml_path, yaml_config)
224234

225235

226236
def path_size_megabytes(path: str) -> int:
@@ -726,13 +736,18 @@ def write_cluster_config(to_provision: 'resources.Resources',
726736
'ssh_private_key': (None if auth_config is None else
727737
auth_config['ssh_private_key']),
728738
}),
729-
dryrun=dryrun,
730739
)
731740
config_dict['cluster_name'] = cluster_name
732741
config_dict['ray'] = yaml_path
733742
if dryrun:
734743
return config_dict
735744
_add_auth_to_cluster_config(cloud, yaml_path)
745+
# Delay the optimization of the config until the authentication files is added.
746+
if not isinstance(cloud, clouds.Local):
747+
# Only optimize the file mounts for public clouds now, as local has not
748+
# been fully tested yet.
749+
_optimize_file_mounts(yaml_path)
750+
736751
usage_lib.messages.usage.update_ray_yaml(yaml_path)
737752
# For TPU nodes. TPU VMs do not need TPU_NAME.
738753
if (resources_vars.get('tpu_type') is not None and
@@ -768,8 +783,7 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
768783
769784
This function's output removes comments included in the jinja2 template.
770785
"""
771-
with open(cluster_config_file, 'r') as f:
772-
config = yaml.safe_load(f)
786+
config = common_utils.read_yaml(cluster_config_file)
773787
# Check the availability of the cloud type.
774788
if isinstance(cloud, clouds.AWS):
775789
config = auth.setup_aws_authentication(config)

sky/templates/gcp-ray.yml.j2

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,11 @@ head_node_type: ray_head_default
104104

105105
# Format: `REMOTE_PATH : LOCAL_PATH`
106106
file_mounts: {
107-
"{{remote_runtime_files_dir}}": "{{local_runtime_files_dir}}",
107+
"{{sky_ray_yaml_remote_path}}": "{{sky_ray_yaml_local_path}}",
108+
"{{sky_remote_path}}": "{{sky_local_path}}",
109+
{%- for remote_path, local_path in credentials.items() %}
110+
"{{remote_path}}": "{{local_path}}",
111+
{%- endfor %}
108112
}
109113

110114
rsync_exclude: []
@@ -125,8 +129,7 @@ setup_commands:
125129
# default. 'source ~/.bashrc' is needed so conda takes effect for the next
126130
# commands.
127131
# Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys`
128-
- {{postprocess_runtime_files_command}};
129-
mkdir -p ~/.ssh; touch ~/.ssh/config;
132+
- mkdir -p ~/.ssh; touch ~/.ssh/config;
130133
pip3 --version > /dev/null 2>&1 || (curl -sSL https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py && echo "PATH=$HOME/.local/bin:$PATH" >> ~/.bashrc); (type -a python | grep -q python3) || echo 'alias python=python3' >> ~/.bashrc; (type -a pip | grep -q pip3) || echo 'alias pip=pip3' >> ~/.bashrc;
131134
which conda > /dev/null 2>&1 || (wget -nc https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && bash Miniconda3-latest-Linux-x86_64.sh -b && eval "$(/home/gcpuser/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true);
132135
source ~/.bashrc;

tests/test_smoke.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def run_one_test(test: Test) -> Tuple[int, str, str]:
8282
log_file.flush()
8383
test.echo(f'Timeout after {test.timeout} seconds.')
8484
test.echo(e)
85+
log_file.write(f'Timeout after {test.timeout} seconds.\n')
86+
log_file.flush()
8587
# Kill the current process.
8688
proc.terminate()
8789
proc.returncode = 1 # None if we don't set it.

0 commit comments

Comments
 (0)