Skip to content

Commit d370176

Browse files
authored
[AWS] Update catalog from host periodically (skypilot-org#1451)
* wip
* Add auto update and preferred area
* Add filters for AWS
* format
* fix
* Auto-update by default
* turn off non default aws regions for global
* address comments
* Map AZ ID to AZ name
* fix comment
* fix name
* fix
* fix
* fix path
* format
* Fix
* wip: fix
* address comments
* format
* Add ap-south-1
* update faq for download global regions
* modify docs
* update docs
* fix pytest
* format
* fix pytest
* Fetch the az mapping during checking
* format
* remove output for downloading mapping / catalogs
* fix
* fetching
* increase spot time
* increase timeout
* fix no follow for spot
* format
1 parent 4fc9a6d commit d370176

File tree

11 files changed

+281
-66
lines changed

11 files changed

+281
-66
lines changed

docs/source/reference/faq.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,17 @@ To avoid rerunning the ``setup`` commands, pass the ``--no-setup`` flag to ``sky
7979
(Advanced) How to make SkyPilot use all global regions?
8080
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8181

82-
By default, SkyPilot only supports the US regions on different clouds for convenience. If you want to utilize all global regions, please run the following command:
82+
By default, SkyPilot supports most global regions on AWS, but only the US regions on GCP and Azure. If you want to utilize all global regions, please run the following commands:
8383

8484
.. code-block:: bash
8585
86-
mkdir -p ~/.sky/catalogs/v4
87-
cd ~/.sky/catalogs/v4
88-
# Fetch all regions for AWS
89-
python -m sky.clouds.service_catalog.data_fetchers.fetch_aws --all-regions
86+
version=$(python -c 'import sky; print(sky.clouds.service_catalog.constants.CATALOG_SCHEMA_VERSION)')
87+
mkdir -p ~/.sky/catalogs/${version}
88+
cd ~/.sky/catalogs/${version}
9089
# Fetch all regions for GCP
9190
pip install lxml
9291
python -m sky.clouds.service_catalog.data_fetchers.fetch_gcp --all-regions
92+
9393
# Fetch all regions for Azure
9494
python -m sky.clouds.service_catalog.data_fetchers.fetch_azure --all-regions
9595

sky/backends/backend_utils.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Util constants/functions for the backends."""
2-
import contextlib
32
import copy
43
from datetime import datetime
54
import difflib
@@ -1429,15 +1428,6 @@ def _process_cli_query(
14291428
]
14301429

14311430

1432-
@contextlib.contextmanager
1433-
def suppress_output():
1434-
"""Suppress stdout and stderr."""
1435-
with open(os.devnull, 'w') as devnull:
1436-
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
1437-
devnull):
1438-
yield
1439-
1440-
14411431
def _ray_launch_hash(cluster_name: str, ray_config: Dict[str, Any]) -> Set[str]:
14421432
"""Returns a set of Ray launch config hashes, one per node type."""
14431433
# Use the cached Ray launch hashes if they exist.
@@ -1447,7 +1437,7 @@ def _ray_launch_hash(cluster_name: str, ray_config: Dict[str, Any]) -> Set[str]:
14471437
if ray_launch_hashes is not None:
14481438
logger.debug('Using cached launch_caches')
14491439
return set(ray_launch_hashes)
1450-
with suppress_output():
1440+
with ux_utils.suppress_output():
14511441
ray_config = ray_commands._bootstrap_config(ray_config) # pylint: disable=protected-access
14521442
# Adopted from https://github.com/ray-project/ray/blob/ray-2.0.1/python/ray/autoscaler/_private/node_launcher.py#L87-L97
14531443
# TODO(zhwu): this logic is duplicated from the ray code above (keep in sync).
@@ -1461,7 +1451,7 @@ def _ray_launch_hash(cluster_name: str, ray_config: Dict[str, Any]) -> Set[str]:
14611451
launch_config = copy.deepcopy(launch_config)
14621452

14631453
launch_config.update(node_config['node_config'])
1464-
with suppress_output():
1454+
with ux_utils.suppress_output():
14651455
current_hash = ray_util.hash_launch_conf(launch_config,
14661456
ray_config['auth'])
14671457
launch_hashes.add(current_hash)

sky/clouds/aws.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,9 @@ def check_credentials(self) -> Tuple[bool, Optional[str]]:
330330
' Make sure that the access and secret keys are correct.'
331331
' To reconfigure the credentials, ' + help_str[1].lower() +
332332
help_str[2:])
333+
334+
# Fetch the AWS availability zones mapping from ID to name.
335+
from sky.clouds.service_catalog import aws_catalog # pylint: disable=import-outside-toplevel,unused-import
333336
return True, None
334337

335338
def get_credential_file_mounts(self) -> Dict[str, str]:

sky/clouds/service_catalog/aws_catalog.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,62 @@
33
This module loads the service catalog file and can be used to query
44
instance types and pricing information for AWS.
55
"""
6+
import colorama
7+
import os
68
import typing
79
from typing import Dict, List, Optional, Tuple
810

11+
import pandas as pd
12+
13+
from sky import sky_logging
914
from sky.clouds.service_catalog import common
15+
from sky.utils import ux_utils
1016

1117
if typing.TYPE_CHECKING:
1218
from sky.clouds import cloud
1319

14-
_df = common.read_catalog('aws/vms.csv')
15-
_image_df = common.read_catalog('aws/images.csv')
20+
logger = sky_logging.init_logger(__name__)
21+
22+
# Keep it synced with the frequency in
23+
# skypilot-catalog/.github/workflows/update-aws-catalog.yml
24+
_PULL_FREQUENCY_HOURS = 7
25+
26+
_df = common.read_catalog('aws/vms.csv',
27+
pull_frequency_hours=_PULL_FREQUENCY_HOURS)
28+
_image_df = common.read_catalog('aws/images.csv',
29+
pull_frequency_hours=_PULL_FREQUENCY_HOURS)
30+
31+
32+
def _apply_az_mapping(df: 'pd.DataFrame') -> 'pd.DataFrame':
33+
"""Maps zone IDs (use1-az1) to zone names (us-east-1x).
34+
35+
Such mappings are account-specific and determined by AWS.
36+
37+
Returns:
38+
A dataframe with column 'AvailabilityZone' that's correctly replaced
39+
with the zone name (e.g. us-east-1a).
40+
"""
41+
az_mapping_path = common.get_catalog_path('aws/az_mappings.csv')
42+
if not os.path.exists(az_mapping_path):
43+
# Fetch az mapping from AWS.
44+
# pylint: disable=import-outside-toplevel
45+
import ray
46+
from sky.clouds.service_catalog.data_fetchers import fetch_aws
47+
logger.info(f'{colorama.Style.DIM}Fetching availability zones mapping '
48+
f'for AWS...{colorama.Style.RESET_ALL}')
49+
with ux_utils.suppress_output():
50+
ray.init()
51+
az_mappings = fetch_aws.fetch_availability_zone_mappings()
52+
az_mappings.to_csv(az_mapping_path, index=False)
53+
else:
54+
az_mappings = pd.read_csv(az_mapping_path)
55+
df = df.merge(az_mappings, on=['AvailabilityZone'], how='left')
56+
df = df.drop(columns=['AvailabilityZone']).rename(
57+
columns={'AvailabilityZoneName': 'AvailabilityZone'})
58+
return df
59+
60+
61+
_df = _apply_az_mapping(_df)
1662

1763

1864
def instance_type_exists(instance_type: str) -> bool:

sky/clouds/service_catalog/common.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Common utilities for service catalog."""
2+
import hashlib
23
import os
4+
import time
35
from typing import Dict, List, NamedTuple, Optional, Tuple
46

57
import difflib
8+
import filelock
69
import requests
710
import pandas as pd
811

@@ -46,32 +49,70 @@ def get_catalog_path(filename: str) -> str:
4649
return os.path.join(_CATALOG_DIR, filename)
4750

4851

49-
def read_catalog(filename: str) -> pd.DataFrame:
52+
def read_catalog(filename: str,
53+
pull_frequency_hours: Optional[int] = None) -> pd.DataFrame:
5054
"""Reads the catalog from a local CSV file.
5155
5256
If the file does not exist, download the up-to-date catalog that matches
5357
the schema version.
58+
If `pull_frequency_hours` is not None: pull the latest catalog with
59+
possibly updated prices, if the local catalog file is older than
60+
`pull_frequency_hours` and no changes to the local catalog file are
61+
made after the last pull.
5462
"""
5563
assert filename.endswith('.csv'), 'The catalog file must be a CSV file.'
64+
assert (pull_frequency_hours is None or
65+
pull_frequency_hours > 0), pull_frequency_hours
5666
catalog_path = get_catalog_path(filename)
5767
cloud = cloud_lib.CLOUD_REGISTRY.from_str(os.path.dirname(filename))
58-
if not os.path.exists(catalog_path):
59-
url = f'{constants.HOSTED_CATALOG_DIR_URL}/{constants.CATALOG_SCHEMA_VERSION}/{filename}' # pylint: disable=line-too-long
60-
with backend_utils.safe_console_status(
61-
f'Downloading {cloud} catalog...'):
62-
try:
63-
r = requests.get(url)
64-
r.raise_for_status()
65-
except requests.exceptions.RequestException as e:
66-
logger.error(f'Failed to download {cloud} catalog:')
67-
with ux_utils.print_exception_no_traceback():
68-
raise e
69-
# Save the catalog to a local file.
70-
os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
71-
with open(catalog_path, 'w') as f:
72-
f.write(r.text)
73-
logger.info(f'A new {cloud} catalog has been downloaded to '
74-
f'{catalog_path}')
68+
69+
meta_path = os.path.join(_CATALOG_DIR, '.meta', filename)
70+
os.makedirs(os.path.dirname(meta_path), exist_ok=True)
71+
72+
# Atomic check, to avoid conflicts with other processes.
73+
# TODO(mraheja): remove pylint disabling when filelock version updated
74+
# pylint: disable=abstract-class-instantiated
75+
with filelock.FileLock(meta_path + '.lock'):
76+
77+
def _need_update() -> bool:
78+
if not os.path.exists(catalog_path):
79+
return True
80+
if pull_frequency_hours is None:
81+
return False
82+
# Check the md5 of the file to see if it has changed.
83+
with open(catalog_path, 'rb') as f:
84+
file_md5 = hashlib.md5(f.read()).hexdigest()
85+
md5_filepath = meta_path + '.md5'
86+
if os.path.exists(md5_filepath):
87+
with open(md5_filepath, 'r') as f:
88+
last_md5 = f.read()
89+
if file_md5 != last_md5:
90+
# Do not update the file if the user modified it.
91+
return False
92+
93+
last_update = os.path.getmtime(catalog_path)
94+
return last_update + pull_frequency_hours * 3600 < time.time()
95+
96+
if _need_update():
97+
url = f'{constants.HOSTED_CATALOG_DIR_URL}/{constants.CATALOG_SCHEMA_VERSION}/{filename}' # pylint: disable=line-too-long
98+
update_frequency_str = ''
99+
if pull_frequency_hours is not None:
100+
update_frequency_str = f' (every {pull_frequency_hours} hours)'
101+
with backend_utils.safe_console_status(
102+
f'Updating {cloud} catalog{update_frequency_str}'):
103+
try:
104+
r = requests.get(url)
105+
r.raise_for_status()
106+
except requests.exceptions.RequestException as e:
107+
logger.error(f'Failed to download {cloud} catalog:')
108+
with ux_utils.print_exception_no_traceback():
109+
raise e
110+
# Save the catalog to a local file.
111+
os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
112+
with open(catalog_path, 'w') as f:
113+
f.write(r.text)
114+
with open(meta_path + '.md5', 'w') as f:
115+
f.write(hashlib.md5(r.text.encode()).hexdigest())
75116

76117
try:
77118
df = pd.read_csv(catalog_path)

sky/clouds/service_catalog/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
import os
33

44
HOSTED_CATALOG_DIR_URL = 'https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/master/catalogs' # pylint: disable=line-too-long
5-
CATALOG_SCHEMA_VERSION = 'v4'
5+
CATALOG_SCHEMA_VERSION = 'v5'
66
LOCAL_CATALOG_DIR = os.path.expanduser('~/.sky/catalogs/')

0 commit comments

Comments
 (0)