Skip to content

Commit bfa7eab

Browse files
authored
Enable custom image for AWS and GCP (skypilot-org#931)
* add image id * format * Add schema * fix resources equality * add gcp image id * retry ray up * Install pip * Remove output of pip installation * Fix conda * Add docs * fix typo * Fix aws worker image * Add tests * Fix azure and add customized image example * format * format * Address comments * fix typo * Fix output * fix ux_utils * revert minimal * fix template * Address comments * Reverted to python3 * revert parens * fix conda env
1 parent 840f655 commit bfa7eab

19 files changed

+283
-67
lines changed

docs/source/reference/yaml-spec.rst

+10
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ describe all fields available.
7979
tpu_name: mytpu
8080
tpu_vm: False # False to use TPU nodes (the default); True to use TPU VMs.
8181
82+
# Custom image id (optional, advanced). The image id used to boot the
83+
# instances. Only supported for AWS and GCP. If not specified, sky will use
84+
# the default Debian-based image suitable for machine learning tasks.
85+
# To find AWS AMI ids: https://leaherb.com/how-to-find-an-aws-marketplace-ami-image-id
86+
# AWS
87+
image_id: ami-0868a20f5a3bf9702
88+
# To find GCP images: https://cloud.google.com/compute/docs/images
89+
# GCP
90+
# image_id: projects/deeplearning-platform-release/global/images/family/tf2-ent-2-1-cpu-ubuntu-2004
91+
8292
file_mounts:
8393
# Uses rsync to copy local files to all nodes of the cluster.
8494
#

examples/custom_image.yaml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
resources:
2+
cloud: aws
3+
region: us-east-2
4+
# Nvidia image from
5+
# https://aws.amazon.com/marketplace/pp/prodview-rf7na2b2ttvdg
6+
image_id: ami-062ddd90fb6f8267a
7+
accelerators: V100
8+
9+
setup: |
10+
echo "running setup"
11+
12+
run: |
13+
echo "hello sky"

sky/backends/backend_utils.py

+4-31
Original file line numberDiff line numberDiff line change
@@ -558,32 +558,10 @@ def write_cluster_config(to_provision: 'resources.Resources',
558558
# task.best_resources may not be equal to to_provision if the user
559559
# is running a job with less resources than the cluster has.
560560
cloud = to_provision.cloud
561-
resources_vars = cloud.make_deploy_resources_variables(to_provision)
561+
resources_vars = cloud.make_deploy_resources_variables(
562+
to_provision, region, zones)
562563
config_dict = {}
563564

564-
if region is None:
565-
assert zones is None, 'Set either both or neither for: region, zones.'
566-
region = cloud.get_default_region()
567-
zones = region.zones
568-
else:
569-
assert isinstance(
570-
cloud, clouds.Azure
571-
) or zones is not None, 'Set either both or neither for: region, zones.'
572-
region = region.name
573-
if isinstance(cloud, clouds.AWS):
574-
# Only AWS supports multiple zones in the 'availability_zone' field.
575-
zones = [zone.name for zone in zones]
576-
elif isinstance(cloud, clouds.Azure):
577-
# Azure does not support specific zones.
578-
zones = []
579-
else:
580-
zones = [zones[0].name]
581-
582-
aws_default_ami = None
583-
if isinstance(cloud, clouds.AWS):
584-
instance_type = resources_vars['instance_type']
585-
aws_default_ami = cloud.get_default_ami(region, instance_type)
586-
587565
azure_subscription_id = None
588566
if isinstance(cloud, clouds.Azure):
589567
azure_subscription_id = cloud.get_project_id(dryrun=dryrun)
@@ -594,6 +572,7 @@ def write_cluster_config(to_provision: 'resources.Resources',
594572

595573
assert cluster_name is not None
596574
credentials = sky_check.get_cloud_credential_file_mounts()
575+
region_name = resources_vars['region']
597576
yaml_path = fill_template(
598577
cluster_config_template,
599578
dict(
@@ -602,11 +581,6 @@ def write_cluster_config(to_provision: 'resources.Resources',
602581
'cluster_name': cluster_name,
603582
'num_nodes': num_nodes,
604583
'disk_size': to_provision.disk_size,
605-
# Region/zones.
606-
'region': region,
607-
'zones': ','.join(zones),
608-
# AWS only.
609-
'aws_default_ami': aws_default_ami,
610584
# Temporary measure, as deleting per-cluster SGs is too slow.
611585
# See https://github.com/sky-proj/sky/pull/742.
612586
# Generate the name of the security group we're looking for.
@@ -616,7 +590,7 @@ def write_cluster_config(to_provision: 'resources.Resources',
616590
'security_group': f'sky-sg-{user_and_hostname_hash()}',
617591
# Azure only.
618592
'azure_subscription_id': azure_subscription_id,
619-
'resource_group': f'{cluster_name}-{region}',
593+
'resource_group': f'{cluster_name}-{region_name}',
620594
# GCP only.
621595
'gcp_project_id': gcp_project_id,
622596
# Ray version.
@@ -645,7 +619,6 @@ def write_cluster_config(to_provision: 'resources.Resources',
645619
template_name,
646620
dict(
647621
resources_vars, **{
648-
'zones': ','.join(zones),
649622
'tpu_name': tpu_name,
650623
'gcp_project_id': gcp_project_id,
651624
}),

sky/backends/cloud_vm_ray_backend.py

+10
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,16 @@ def ray_up():
10191019
f'{region_name}{colorama.Style.RESET_ALL} ({zone_str})')
10201020
start = time.time()
10211021
returncode, stdout, stderr = ray_up()
1022+
if (returncode != 0 and 'Processing file mounts' in stdout and
1023+
'Running setup commands' not in stdout):
1024+
# Retry ray up if it failed due to file mounts, because it is
1025+
# probably caused by too many concurrent SSH connections and can be fixed
1026+
# by retrying.
1027+
# This is required when using custom image for GCP.
1028+
logger.info(
1029+
'Retrying sky runtime setup due to ssh connection issue.')
1030+
returncode, stdout, stderr = ray_up()
1031+
10221032
logger.debug(f'Ray up takes {time.time() - start} seconds.')
10231033

10241034
# Only 1 node or head node provisioning failure.

sky/cli.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,11 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]:
261261
default=None,
262262
help=('Whether to request spot instances. If specified, overrides the '
263263
'"resources.use_spot" config.')),
264+
click.option('--image-id',
265+
required=False,
266+
default=None,
267+
help=('Custom image id for launching the instances. '
268+
'Passing "none" resets the config.')),
264269
click.option(
265270
'--env',
266271
required=False,
@@ -552,6 +557,7 @@ def _make_dag_from_entrypoint_with_overrides(
552557
gpus: Optional[str] = None,
553558
num_nodes: Optional[int] = None,
554559
use_spot: Optional[bool] = None,
560+
image_id: Optional[str] = None,
555561
disk_size: Optional[int] = None,
556562
env: List[Dict[str, str]] = None,
557563
# spot launch specific
@@ -563,16 +569,17 @@ def _make_dag_from_entrypoint_with_overrides(
563569
if _check_yaml(entrypoint):
564570
# Treat entrypoint as a yaml.
565571
click.secho('Task from YAML spec: ', fg='yellow', nl=False)
572+
click.secho(entrypoint, bold=True)
566573
task = sky.Task.from_yaml(entrypoint)
567574
else:
568575
if not entrypoint:
569576
entrypoint = None
570577
else:
571578
# Treat entrypoint as a bash command.
572579
click.secho('Task from command: ', fg='yellow', nl=False)
580+
click.secho(entrypoint, bold=True)
573581
task = sky.Task(name='sky-cmd', run=entrypoint)
574582
task.set_resources({sky.Resources()})
575-
click.secho(entrypoint, bold=True)
576583
# Override.
577584
if workdir is not None:
578585
task.workdir = workdir
@@ -597,6 +604,11 @@ def _make_dag_from_entrypoint_with_overrides(
597604
override_params['use_spot'] = use_spot
598605
if disk_size is not None:
599606
override_params['disk_size'] = disk_size
607+
if image_id is not None:
608+
if image_id.lower() == 'none':
609+
override_params['image_id'] = None
610+
else:
611+
override_params['image_id'] = image_id
600612

601613
# Spot launch specific.
602614
if spot_recovery is not None:
@@ -735,6 +747,7 @@ def launch(
735747
gpus: Optional[str],
736748
num_nodes: Optional[int],
737749
use_spot: Optional[bool],
750+
image_id: Optional[str],
738751
env: List[Dict[str, str]],
739752
disk_size: Optional[int],
740753
idle_minutes_to_autostop: Optional[int],
@@ -763,6 +776,7 @@ def launch(
763776
gpus=gpus,
764777
num_nodes=num_nodes,
765778
use_spot=use_spot,
779+
image_id=image_id,
766780
env=env,
767781
disk_size=disk_size,
768782
)
@@ -808,6 +822,7 @@ def exec(
808822
gpus: Optional[str],
809823
num_nodes: Optional[int],
810824
use_spot: Optional[bool],
825+
image_id: Optional[str],
811826
env: List[Dict[str, str]],
812827
):
813828
"""Execute a task or a command on a cluster (skip setup).
@@ -882,6 +897,7 @@ def exec(
882897
region=region,
883898
gpus=gpus,
884899
use_spot=use_spot,
900+
image_id=image_id,
885901
num_nodes=num_nodes,
886902
env=env,
887903
)
@@ -2064,6 +2080,7 @@ def spot_launch(
20642080
gpus: Optional[str],
20652081
num_nodes: Optional[int],
20662082
use_spot: Optional[bool],
2083+
image_id: Optional[str],
20672084
spot_recovery: Optional[str],
20682085
env: List[Dict[str, str]],
20692086
disk_size: Optional[int],
@@ -2085,6 +2102,7 @@ def spot_launch(
20852102
gpus=gpus,
20862103
num_nodes=num_nodes,
20872104
use_spot=use_spot,
2105+
image_id=image_id,
20882106
env=env,
20892107
disk_size=disk_size,
20902108
spot_recovery=spot_recovery,

sky/clouds/aws.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -178,19 +178,42 @@ def get_accelerators_from_instance_type(
178178
return service_catalog.get_accelerators_from_instance_type(
179179
instance_type, clouds='aws')
180180

181-
def make_deploy_resources_variables(self,
182-
resources: 'resources_lib.Resources'):
181+
def make_deploy_resources_variables(
182+
self, resources: 'resources_lib.Resources',
183+
region: Optional['clouds.Region'],
184+
zones: Optional[List['clouds.Zone']]) -> Dict[str, str]:
185+
if region is None:
186+
assert zones is None, (
187+
'Set either both or neither for: region, zones.')
188+
region = self._get_default_region()
189+
zones = region.zones
190+
else:
191+
assert zones is not None, (
192+
'Set either both or neither for: region, zones.')
193+
194+
region_name = region.name
195+
zones = [zone.name for zone in zones]
196+
183197
r = resources
184198
# r.accelerators is cleared but .instance_type encodes the info.
185199
acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
186200
if acc_dict is not None:
187201
custom_resources = json.dumps(acc_dict, separators=(',', ':'))
188202
else:
189203
custom_resources = None
204+
205+
if r.image_id is not None:
206+
image_id = r.image_id
207+
else:
208+
image_id = self.get_default_ami(region_name, r.instance_type)
209+
190210
return {
191211
'instance_type': r.instance_type,
192212
'custom_resources': custom_resources,
193213
'use_spot': r.use_spot,
214+
'region': region_name,
215+
'zones': ','.join(zones),
216+
'image_id': image_id,
194217
}
195218

196219
def get_feasible_launchable_resources(self,

sky/clouds/azure.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
import json
33
import os
44
import subprocess
5+
import typing
56
from typing import Dict, Iterator, List, Optional, Tuple
67

78
from sky import clouds
89
from sky.adaptors import azure
910
from sky.clouds import service_catalog
1011

12+
if typing.TYPE_CHECKING:
13+
from sky import resources
14+
1115
# Minimum set of files under ~/.azure that grant Azure access.
1216
_CREDENTIAL_FILES = [
1317
'azureProfile.json',
@@ -157,7 +161,19 @@ def get_accelerators_from_instance_type(
157161
return service_catalog.get_accelerators_from_instance_type(
158162
instance_type, clouds='azure')
159163

160-
def make_deploy_resources_variables(self, resources):
164+
def make_deploy_resources_variables(
165+
self, resources: 'resources.Resources',
166+
region: Optional['clouds.Region'],
167+
zones: Optional[List['clouds.Zone']]) -> Dict[str, str]:
168+
if region is None:
169+
assert zones is None, (
170+
'Set either both or neither for: region, zones.')
171+
region = self._get_default_region()
172+
173+
region_name = region.name
174+
# Azure does not support specific zones.
175+
zones = []
176+
161177
r = resources
162178
assert not r.use_spot, \
163179
'Our subscription offer ID does not support spot instances.'
@@ -175,6 +191,8 @@ def make_deploy_resources_variables(self, resources):
175191
'instance_type': r.instance_type,
176192
'custom_resources': custom_resources,
177193
'use_spot': r.use_spot,
194+
'region': region_name,
195+
'zones': zones,
178196
**image_config
179197
}
180198

sky/clouds/cloud.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Interfaces: clouds, regions, and zones."""
22
import collections
3+
import typing
34
from typing import Dict, Iterator, List, Optional, Tuple
45

6+
if typing.TYPE_CHECKING:
7+
from sky import resources
8+
59

610
class Region(collections.namedtuple('Region', ['name'])):
711
"""A region."""
@@ -107,7 +111,12 @@ def get_egress_cost(self, num_gigabytes):
107111
def is_same_cloud(self, other):
108112
raise NotImplementedError
109113

110-
def make_deploy_resources_variables(self, resources):
114+
def make_deploy_resources_variables(
115+
self,
116+
resources: 'resources.Resources',
117+
region: Optional['Region'],
118+
zones: Optional[List['Zone']],
119+
) -> Dict[str, str]:
111120
"""Converts planned sky.Resources to cloud-specific resource variables.
112121
113122
These variables are used to fill the node type section (instance type,
@@ -136,7 +145,7 @@ def get_default_instance_type(cls,
136145
raise NotImplementedError
137146

138147
@classmethod
139-
def get_default_region(cls) -> Region:
148+
def _get_default_region(cls) -> Region:
140149
raise NotImplementedError
141150

142151
def get_feasible_launchable_resources(self, resources):

0 commit comments

Comments
 (0)