Skip to content

Commit b00a359

Browse files
authored
[Custom Image] Support tag for the images and global regions (skypilot-org#1366)
* Support image tag for AWS * add gcp image support * address comments * fix * remove pandas warning * Add example for using ubuntu1804 * add ubuntu 1804 in the test * Enforce trying us regions first * format * address comments * address comments * Add docs and rename methods * Add fetch global regions for GCP * Add all regions for Azure * rename and add doc * remove accidentally added folder * fix service_catalog * remove extra line * Address comments * mkdir for catalog path * increase waiting time in test * fix test recovery * format
1 parent c52e9d4 commit b00a359

18 files changed

+366
-118
lines changed

docs/source/reference/faq.rst

+18
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ If you have edited the ``file_mounts`` section (e.g., by adding some files) and
7676
To avoid rerunning the ``setup`` commands, pass the ``--no-setup`` flag to ``sky launch``.
7777

7878

79+
(Advanced) How to make SkyPilot use all global regions?
80+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
81+
82+
By default, SkyPilot only supports the US regions on different clouds for convenience. If you want to utilize all global regions, please run the following commands:
83+
84+
.. code-block:: bash
85+
86+
cd ~/.sky/catalogs/v4
87+
# Fetch all regions for AWS
88+
python -m sky.clouds.service_catalog.data_fetchers.fetch_aws --all-regions
89+
# Fetch all regions for GCP
90+
python -m sky.clouds.service_catalog.data_fetchers.fetch_gcp --all-regions
91+
# Fetch all regions for Azure
92+
python -m sky.clouds.service_catalog.data_fetchers.fetch_azure --all-regions
93+
94+
To make your managed spot jobs potentially use all global regions, please log into the spot controller with ``ssh sky-spot-controller-<hash>``
95+
(the full name can be found in ``sky status``), and run the commands above.
96+
7997

8098
(Advanced) How to edit or update the regions or pricing information used by SkyPilot?
8199
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

docs/source/reference/yaml-spec.rst

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ Available fields:
9191
#
9292
# AWS
9393
# To find AWS AMI ids: https://leaherb.com/how-to-find-an-aws-marketplace-ami-image-id
94+
# You can also change the default OS version by choosing from the following image tags provided by SkyPilot:
95+
# image_id: skypilot:gpu-ubuntu-2004
96+
# image_id: skypilot:k80-ubuntu-2004
97+
# image_id: skypilot:gpu-ubuntu-1804
98+
# image_id: skypilot:k80-ubuntu-1804
9499
image_id: ami-0868a20f5a3bf9702
95100
# GCP
96101
# To find GCP images: https://cloud.google.com/compute/docs/images

examples/image_with_tag.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
resources:
2+
cloud: aws
3+
image_id: skypilot:gpu-ubuntu-1804
4+
5+
6+
setup: |
7+
echo "running setup"
8+
9+
run: |
10+
conda env list

sky/clouds/aws.py

+33-42
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict, Iterator, List, Optional, Tuple
1010

1111
from sky import clouds
12+
from sky import exceptions
1213
from sky.clouds import service_catalog
1314

1415
if typing.TYPE_CHECKING:
@@ -94,45 +95,41 @@ def region_zones_provision_loop(
9495
@classmethod
9596
def get_default_ami(cls, region_name: str, instance_type: str) -> str:
9697
acc = cls.get_accelerators_from_instance_type(instance_type)
98+
image_id = service_catalog.get_image_id_from_tag(
99+
'skypilot:gpu-ubuntu-2004', region_name, clouds='aws')
97100
if acc is not None:
98101
assert len(acc) == 1, acc
99102
acc_name = list(acc.keys())[0]
100103
if acc_name == 'K80':
101-
# Deep Learning AMI GPU PyTorch 1.10.0 (Ubuntu 20.04) 20211208
102-
# Downgrade the AMI for K80 due as it is only compatible with
103-
# NVIDIA driver lower than 470.
104-
amis = {
105-
'us-east-1': 'ami-0868a20f5a3bf9702',
106-
'us-east-2': 'ami-09b8825010d4dc701',
107-
# This AMI is 20210623 as aws does not provide a newer one.
108-
'us-west-1': 'ami-0b3c34d643904a734',
109-
'us-west-2': 'ami-06b3479ab15aaeaf1',
110-
}
111-
assert region_name in amis, region_name
112-
return amis[region_name]
113-
# https://console.aws.amazon.com/ec2/v2/home?region=us-east-1#Images:visibility=public-images;v=3;search=:64,:Ubuntu%2020,:Deep%20Learning%20AMI%20GPU%20PyTorch # pylint: disable=line-too-long
114-
115-
# Commented below are newer AMIs, but as other clouds do not support
116-
# torch==1.13.0+cu117 we do not use these AMIs to avoid frequent updates:
117-
# Deep Learning AMI GPU PyTorch 1.12.1 (Ubuntu 20.04) 20221025
118-
# Nvidia driver: 510.47.03, CUDA Version: 11.6 (supports torch==1.13.0+cu117)
119-
# 'us-east-1': 'ami-0eb1f91977a3fcc1b'
120-
# 'us-east-2': 'ami-0274a6db2e19b7cc6'
121-
# 'us-west-1': 'ami-0fb299af41d32cfd3'
122-
# 'us-west-2': 'ami-04ba15f9bd464eb20'
123-
#
124-
# Current AMIs:
125-
# Deep Learning AMI GPU PyTorch 1.10.0 (Ubuntu 20.04) 20220308
126-
# Nvidia driver: 510.47.03, CUDA Version: 11.6 (does not support torch==1.13.0+cu117)
127-
amis = {
128-
'us-east-1': 'ami-0729d913a335efca7',
129-
'us-east-2': 'ami-070f4af81c19b41bf',
130-
# This AMI is 20210623 as aws does not provide a newer one.
131-
'us-west-1': 'ami-0b3c34d643904a734',
132-
'us-west-2': 'ami-050814f384259894c',
133-
}
134-
assert region_name in amis, region_name
135-
return amis[region_name]
104+
image_id = service_catalog.get_image_id_from_tag(
105+
'skypilot:k80-ubuntu-2004', region_name, clouds='aws')
106+
if image_id is not None:
107+
return image_id
108+
# Raise ResourcesUnavailableError to make sure the failover in
109+
# CloudVMRayBackend will be correctly triggered.
110+
# TODO(zhwu): This is an information leakage to the cloud implementor,
111+
# we need to find a better way to handle this.
112+
raise exceptions.ResourcesUnavailableError(
113+
'No image found in catalog for region '
114+
f'{region_name}. Try setting a valid image_id.')
115+
116+
@classmethod
117+
def _get_image_id(cls, region_name: str, instance_type: str,
118+
image_id: Optional[str]) -> str:
119+
if image_id is not None:
120+
if image_id.startswith('skypilot:'):
121+
image_id = service_catalog.get_image_id_from_tag(image_id,
122+
region_name,
123+
clouds='aws')
124+
if image_id is None:
125+
# Raise ResourcesUnavailableError to make sure the failover
126+
# in CloudVMRayBackend will be correctly triggered.
127+
# TODO(zhwu): This is an information leakage to the cloud
128+
# implementor, we need to find a better way to handle this.
129+
raise exceptions.ResourcesUnavailableError(
130+
f'No image found for region {region_name}')
131+
return image_id
132+
return cls.get_default_ami(region_name, instance_type)
136133

137134
@classmethod
138135
def get_zone_shell_cmd(cls) -> Optional[str]:
@@ -232,10 +229,7 @@ def make_deploy_resources_variables(
232229
else:
233230
custom_resources = None
234231

235-
if r.image_id is not None:
236-
image_id = r.image_id
237-
else:
238-
image_id = self.get_default_ami(region_name, r.instance_type)
232+
image_id = self._get_image_id(region_name, r.instance_type, r.image_id)
239233

240234
return {
241235
'instance_type': r.instance_type,
@@ -342,9 +336,6 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
342336
def instance_type_exists(self, instance_type):
343337
return service_catalog.instance_type_exists(instance_type, clouds='aws')
344338

345-
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
346-
return service_catalog.validate_region_zone(region, zone, clouds='aws')
347-
348339
def accelerator_in_region_or_zone(self,
349340
accelerator: str,
350341
acc_count: int,

sky/clouds/azure.py

-5
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,6 @@ def instance_type_exists(self, instance_type):
297297
return service_catalog.instance_type_exists(instance_type,
298298
clouds='azure')
299299

300-
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
301-
return service_catalog.validate_region_zone(region,
302-
zone,
303-
clouds='azure')
304-
305300
def accelerator_in_region_or_zone(self,
306301
accelerator: str,
307302
acc_count: int,

sky/clouds/cloud.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing
44
from typing import Dict, Iterator, List, Optional, Tuple
55

6+
from sky.clouds import service_catalog
67
from sky.utils import ux_utils
78

89
if typing.TYPE_CHECKING:
@@ -162,6 +163,13 @@ def get_default_instance_type(cls) -> str:
162163
def _get_default_region(cls) -> Region:
163164
raise NotImplementedError
164165

166+
@classmethod
167+
def is_image_tag_valid(cls, image_tag: str, region: Optional[str]) -> bool:
168+
"""Validates that the image tag is valid for this cloud."""
169+
return service_catalog.is_image_tag_valid(image_tag,
170+
region,
171+
clouds=cls._REPR.lower())
172+
165173
def get_feasible_launchable_resources(self, resources):
166174
"""Returns a list of feasible and launchable resources.
167175
@@ -194,7 +202,9 @@ def instance_type_exists(self, instance_type):
194202

195203
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
196204
"""Validates the region and zone."""
197-
raise NotImplementedError
205+
return service_catalog.validate_region_zone(region,
206+
zone,
207+
clouds=self._REPR.lower())
198208

199209
def accelerator_in_region_or_zone(self,
200210
accelerator: str,

sky/clouds/gcp.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
'active_config',
3535
]
3636

37-
_IMAGE_ID_PREFIX = ('projects/deeplearning-platform-release/global/images/')
38-
3937
_GCLOUD_INSTALLATION_LOG = '~/.sky/logs/gcloud_installation.log'
4038
# Need to be run with /bin/bash
4139
# We factor out the installation logic to keep it aligned in both spot
@@ -217,7 +215,10 @@ def make_deploy_resources_variables(
217215
# gcloud compute images list \
218216
# --project deeplearning-platform-release \
219217
# --no-standard-images
220-
image_id = _IMAGE_ID_PREFIX + 'common-cpu-v20220806'
218+
# We use the debian image, as the ubuntu image has some connectivity
219+
# issue when first booted.
220+
image_id = service_catalog.get_image_id_from_tag(
221+
'skypilot:cpu-debian-10', clouds='gcp')
221222

222223
r = resources
223224
# Find GPU spec, if any.
@@ -261,17 +262,20 @@ def make_deploy_resources_variables(
261262
# Though the image is called cu113, it actually has later
262263
# versions of CUDA as noted below.
263264
# CUDA driver version 470.57.02, CUDA Library 11.4
264-
image_id = _IMAGE_ID_PREFIX + 'common-cu113-v20220701'
265+
image_id = service_catalog.get_image_id_from_tag(
266+
'skypilot:k80-debian-10', clouds='gcp')
265267
else:
266268
# Though the image is called cu113, it actually has later
267269
# versions of CUDA as noted below.
268270
# CUDA driver version 510.47.03, CUDA Library 11.6
269271
# Does not support torch==1.13.0 with cu117
270-
image_id = _IMAGE_ID_PREFIX + 'common-cu113-v20220806'
272+
image_id = service_catalog.get_image_id_from_tag(
273+
'skypilot:gpu-debian-10', clouds='gcp')
271274

272275
if resources.image_id is not None:
273276
image_id = resources.image_id
274277

278+
assert image_id is not None, (image_id, r)
275279
resources_vars['image_id'] = image_id
276280

277281
return resources_vars
@@ -452,9 +456,6 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
452456
def instance_type_exists(self, instance_type):
453457
return service_catalog.instance_type_exists(instance_type, 'gcp')
454458

455-
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
456-
return service_catalog.validate_region_zone(region, zone, clouds='gcp')
457-
458459
def accelerator_in_region_or_zone(self,
459460
accelerator: str,
460461
acc_count: int,

sky/clouds/service_catalog/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,20 @@ def get_tpus() -> List[str]:
235235
]
236236

237237

238+
def get_image_id_from_tag(tag: str,
239+
region: Optional[str] = None,
240+
clouds: CloudFilter = None) -> str:
241+
"""Returns the image ID from the tag."""
242+
return _map_clouds_catalog(clouds, 'get_image_id_from_tag', tag, region)
243+
244+
245+
def is_image_tag_valid(tag: str,
246+
region: Optional[str],
247+
clouds: CloudFilter = None) -> None:
248+
"""Validates the image tag."""
249+
return _map_clouds_catalog(clouds, 'is_image_tag_valid', tag, region)
250+
251+
238252
__all__ = [
239253
'list_accelerators',
240254
'list_accelerator_counts',
@@ -246,6 +260,9 @@ def get_tpus() -> List[str]:
246260
'get_region_zones_for_accelerators',
247261
'get_common_gpus',
248262
'get_tpus',
263+
# Images
264+
'get_image_id_from_tag',
265+
'is_image_tag_valid',
249266
# Constants
250267
'HOSTED_CATALOG_DIR_URL',
251268
'CATALOG_SCHEMA_VERSION',

sky/clouds/service_catalog/aws_catalog.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@
33
This module loads the service catalog file and can be used to query
44
instance types and pricing information for AWS.
55
"""
6+
import typing
67
from typing import Dict, List, Optional, Tuple
78

8-
from sky.clouds import cloud
99
from sky.clouds.service_catalog import common
1010

11-
_df = common.read_catalog('aws.csv')
11+
if typing.TYPE_CHECKING:
12+
from sky.clouds import cloud
13+
14+
_df = common.read_catalog('aws/vms.csv')
15+
_image_df = common.read_catalog('aws/images.csv')
1216

1317

1418
def instance_type_exists(instance_type: str) -> bool:
@@ -57,9 +61,19 @@ def get_instance_type_for_accelerator(
5761

5862

5963
def get_region_zones_for_instance_type(instance_type: str,
60-
use_spot: bool) -> List[cloud.Region]:
64+
use_spot: bool) -> List['cloud.Region']:
6165
df = _df[_df['InstanceType'] == instance_type]
62-
return common.get_region_zones(df, use_spot)
66+
region_list = common.get_region_zones(df, use_spot)
67+
# Hack: Enforce US regions are always tried first:
68+
# [US regions sorted by price] + [non-US regions sorted by price]
69+
us_region_list = []
70+
other_region_list = []
71+
for region in region_list:
72+
if region.name.startswith('us-'):
73+
us_region_list.append(region)
74+
else:
75+
other_region_list.append(region)
76+
return us_region_list + other_region_list
6377

6478

6579
def list_accelerators(gpus_only: bool,
@@ -69,3 +83,13 @@ def list_accelerators(gpus_only: bool,
6983
"""Returns all instance types in AWS offering accelerators."""
7084
return common.list_accelerators_impl('AWS', _df, gpus_only, name_filter,
7185
case_sensitive)
86+
87+
88+
def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:
89+
"""Returns the image id from the tag."""
90+
return common.get_image_id_from_tag_impl(_image_df, tag, region)
91+
92+
93+
def is_image_tag_valid(tag: str, region: Optional[str]) -> bool:
94+
"""Returns whether the image tag is valid."""
95+
return common.is_image_tag_valid_impl(_image_df, tag, region)

sky/clouds/service_catalog/azure_catalog.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
"""
66
from typing import Dict, List, Optional, Tuple
77

8-
from sky.clouds import cloud
8+
from sky import clouds as cloud_lib
99
from sky.clouds.service_catalog import common
1010
from sky.utils import ux_utils
1111

12-
_df = common.read_catalog('azure.csv')
12+
_df = common.read_catalog('azure/vms.csv')
1313

1414

1515
def instance_type_exists(instance_type: str) -> bool:
@@ -60,8 +60,8 @@ def get_instance_type_for_accelerator(
6060
acc_count=acc_count)
6161

6262

63-
def get_region_zones_for_instance_type(instance_type: str,
64-
use_spot: bool) -> List[cloud.Region]:
63+
def get_region_zones_for_instance_type(
64+
instance_type: str, use_spot: bool) -> List[cloud_lib.Region]:
6565
df = _df[_df['InstanceType'] == instance_type]
6666
return common.get_region_zones(df, use_spot)
6767

0 commit comments

Comments
 (0)