Skip to content

Commit b42e01f

Browse files
author
Vivek Khimani
authored
[sky/feat/show-gpus] adding region based filtering for show-gpus command. (skypilot-org#1187)
* [sky/feat] adding region based filtering for show-gpus command. * [tests] added test for the new region flag. * [workflows] fixed the pylint and yapf workflows. * [merge] rebase after a long time. * [fix] fixes the region column display with -a option * [fix] yapf and pylint formatting * [fix] yapf and pylint formatting * [nits] code review nits and pylint fixes.
1 parent 9fca91b commit b42e01f

File tree

7 files changed

+102
-31
lines changed

7 files changed

+102
-31
lines changed

sky/cli.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2667,8 +2667,20 @@ def check():
26672667
default=None,
26682668
type=str,
26692669
help='Cloud provider to query.')
2670+
@click.option(
2671+
'--region',
2672+
required=False,
2673+
type=str,
2674+
help=
2675+
('The region to use. If not specified, shows accelerators from all regions.'
2676+
),
2677+
)
26702678
@usage_lib.entrypoint
2671-
def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pylint: disable=redefined-builtin
2679+
def show_gpus(
2680+
gpu_name: Optional[str],
2681+
all: bool, # pylint: disable=redefined-builtin
2682+
cloud: Optional[str],
2683+
region: Optional[str]):
26722684
"""Show supported GPU/TPU/accelerators.
26732685
26742686
The names and counts shown can be set in the ``accelerators`` field in task
@@ -2682,9 +2694,14 @@ def show_gpus(gpu_name: Optional[str], all: bool, cloud: Optional[str]): # pyli
26822694
To show all accelerators, including less common ones and their detailed
26832695
information, use ``sky show-gpus --all``.
26842696
2685-
NOTE: The price displayed for each instance type is the lowest across all
2686-
regions for both on-demand and spot instances.
2697+
NOTE: If region is not specified, the price displayed for each instance type
2698+
is the lowest across all regions for both on-demand and spot instances.
26872699
"""
2700+
# validation for the --region flag
2701+
if region is not None and cloud is None:
2702+
raise click.UsageError(
2703+
'The --region flag is only valid when the --cloud flag is set.')
2704+
service_catalog.validate_region_zone(region, None, clouds=cloud)
26882705
show_all = all
26892706
if show_all and gpu_name is not None:
26902707
raise click.UsageError('--all is only allowed without a GPU name.')
@@ -2701,8 +2718,11 @@ def _output():
27012718
['OTHER_GPU', 'AVAILABLE_QUANTITIES'])
27022719

27032720
if gpu_name is None:
2704-
result = service_catalog.list_accelerator_counts(gpus_only=True,
2705-
clouds=cloud)
2721+
result = service_catalog.list_accelerator_counts(
2722+
gpus_only=True,
2723+
clouds=cloud,
2724+
region_filter=region,
2725+
)
27062726
# NVIDIA GPUs
27072727
for gpu in service_catalog.get_common_gpus():
27082728
if gpu in result:
@@ -2730,6 +2750,7 @@ def _output():
27302750
# Show detailed accelerator information
27312751
result = service_catalog.list_accelerators(gpus_only=True,
27322752
name_filter=gpu_name,
2753+
region_filter=region,
27332754
clouds=cloud)
27342755
if len(result) == 0:
27352756
yield f'Resources \'{gpu_name}\' not found. '
@@ -2742,7 +2763,7 @@ def _output():
27422763
yield 'the host VM\'s cost is not included.\n\n'
27432764
import pandas as pd # pylint: disable=import-outside-toplevel
27442765
for i, (gpu, items) in enumerate(result.items()):
2745-
accelerator_table = log_utils.create_table([
2766+
accelerator_table_headers = [
27462767
'GPU',
27472768
'QTY',
27482769
'CLOUD',
@@ -2751,7 +2772,11 @@ def _output():
27512772
'HOST_MEMORY',
27522773
'HOURLY_PRICE',
27532774
'HOURLY_SPOT_PRICE',
2754-
])
2775+
]
2776+
if not show_all:
2777+
accelerator_table_headers.append('REGION')
2778+
accelerator_table = log_utils.create_table(
2779+
accelerator_table_headers)
27552780
for item in items:
27562781
instance_type_str = item.instance_type if not pd.isna(
27572782
item.instance_type) else '(attachable)'
@@ -2769,11 +2794,20 @@ def _output():
27692794
item.price) else '-'
27702795
spot_price_str = f'$ {item.spot_price:.3f}' if not pd.isna(
27712796
item.spot_price) else '-'
2772-
accelerator_table.add_row([
2773-
item.accelerator_name, item.accelerator_count, item.cloud,
2774-
instance_type_str, cpu_str, mem_str, price_str,
2775-
spot_price_str
2776-
])
2797+
region_str = item.region if not pd.isna(item.region) else '-'
2798+
accelerator_table_vals = [
2799+
item.accelerator_name,
2800+
item.accelerator_count,
2801+
item.cloud,
2802+
instance_type_str,
2803+
cpu_str,
2804+
mem_str,
2805+
price_str,
2806+
spot_price_str,
2807+
]
2808+
if not show_all:
2809+
accelerator_table_vals.append(region_str)
2810+
accelerator_table.add_row(accelerator_table_vals)
27772811

27782812
if i != 0:
27792813
yield '\n\n'

sky/clouds/service_catalog/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
4949
def list_accelerators(
5050
gpus_only: bool = True,
5151
name_filter: Optional[str] = None,
52+
region_filter: Optional[str] = None,
5253
clouds: CloudFilter = None,
5354
case_sensitive: bool = True,
5455
) -> 'Dict[str, List[common.InstanceTypeInfo]]':
@@ -58,7 +59,7 @@ def list_accelerators(
5859
of instance type offerings. See usage in cli.py.
5960
"""
6061
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
61-
name_filter, case_sensitive)
62+
name_filter, region_filter, case_sensitive)
6263
if not isinstance(results, list):
6364
results = [results]
6465
ret: Dict[str,
@@ -72,6 +73,7 @@ def list_accelerators(
7273
def list_accelerator_counts(
7374
gpus_only: bool = True,
7475
name_filter: Optional[str] = None,
76+
region_filter: Optional[str] = None,
7577
clouds: CloudFilter = None,
7678
) -> Dict[str, List[int]]:
7779
"""List all accelerators offered by Sky and available counts.
@@ -80,7 +82,7 @@ def list_accelerator_counts(
8082
of available counts. See usage in cli.py.
8183
"""
8284
results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only,
83-
name_filter)
85+
name_filter, region_filter, False)
8486
if not isinstance(results, list):
8587
results = [results]
8688
accelerator_counts: Dict[str, Set[int]] = collections.defaultdict(set)

sky/clouds/service_catalog/aws_catalog.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def instance_type_exists(instance_type: str) -> bool:
7171
def validate_region_zone(
7272
region: Optional[str],
7373
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
74-
return common.validate_region_zone_impl(_df, region, zone)
74+
return common.validate_region_zone_impl('aws', _df, region, zone)
7575

7676

7777
def accelerator_in_region_or_zone(acc_name: str,
@@ -134,13 +134,15 @@ def get_region_zones_for_instance_type(instance_type: str,
134134
return us_region_list + other_region_list
135135

136136

137-
def list_accelerators(gpus_only: bool,
138-
name_filter: Optional[str],
139-
case_sensitive: bool = True
140-
) -> Dict[str, List[common.InstanceTypeInfo]]:
137+
def list_accelerators(
138+
gpus_only: bool,
139+
name_filter: Optional[str],
140+
region_filter: Optional[str],
141+
case_sensitive: bool = True
142+
) -> Dict[str, List[common.InstanceTypeInfo]]:
141143
"""Returns all instance types in AWS offering accelerators."""
142144
return common.list_accelerators_impl('AWS', _df, gpus_only, name_filter,
143-
case_sensitive)
145+
region_filter, case_sensitive)
144146

145147

146148
def get_image_id_from_tag(tag: str, region: Optional[str]) -> Optional[str]:

sky/clouds/service_catalog/azure_catalog.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def validate_region_zone(
2222
if zone is not None:
2323
with ux_utils.print_exception_no_traceback():
2424
raise ValueError('Azure does not support zones.')
25-
return common.validate_region_zone_impl(_df, region, zone)
25+
return common.validate_region_zone_impl('azure', _df, region, zone)
2626

2727

2828
def accelerator_in_region_or_zone(acc_name: str,
@@ -89,10 +89,12 @@ def get_gen_version_from_instance_type(instance_type: str) -> Optional[int]:
8989
return _df[_df['InstanceType'] == instance_type]['Generation'].iloc[0]
9090

9191

92-
def list_accelerators(gpus_only: bool,
93-
name_filter: Optional[str],
94-
case_sensitive: bool = True
95-
) -> Dict[str, List[common.InstanceTypeInfo]]:
92+
def list_accelerators(
93+
gpus_only: bool,
94+
name_filter: Optional[str],
95+
region_filter: Optional[str],
96+
case_sensitive: bool = True
97+
) -> Dict[str, List[common.InstanceTypeInfo]]:
9698
"""Returns all instance types in Azure offering GPUs."""
9799
return common.list_accelerators_impl('Azure', _df, gpus_only, name_filter,
98-
case_sensitive)
100+
region_filter, case_sensitive)

sky/clouds/service_catalog/common.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class InstanceTypeInfo(NamedTuple):
3434
- memory: Instance memory in GiB.
3535
- price: Regular instance price per hour (cheapest across all regions).
3636
- spot_price: Spot instance price per hour (cheapest across all regions).
37+
- region: Region where this instance type belongs to.
3738
"""
3839
cloud: str
3940
instance_type: Optional[str]
@@ -43,6 +44,7 @@ class InstanceTypeInfo(NamedTuple):
4344
memory: Optional[float]
4445
price: float
4546
spot_price: float
47+
region: str
4648

4749

4850
def get_catalog_path(filename: str) -> str:
@@ -161,7 +163,7 @@ def instance_type_exists_impl(df: pd.DataFrame, instance_type: str) -> bool:
161163

162164

163165
def validate_region_zone_impl(
164-
df: pd.DataFrame, region: Optional[str],
166+
cloud_name: str, df: pd.DataFrame, region: Optional[str],
165167
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
166168
"""Validates whether region and zone exist in the catalog."""
167169

@@ -174,6 +176,11 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
174176
candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?'
175177
return candidate_strs
176178

179+
def _get_all_supported_regions_str() -> str:
180+
all_regions: List[str] = sorted(df['Region'].unique().tolist())
181+
return \
182+
f'\nList of supported {cloud_name} regions: {", ".join(all_regions)!r}'
183+
177184
validated_region, validated_zone = region, zone
178185

179186
filter_df = df
@@ -182,7 +189,12 @@ def _get_candidate_str(loc: str, all_loc: List[str]) -> str:
182189
if len(filter_df) == 0:
183190
with ux_utils.print_exception_no_traceback():
184191
error_msg = (f'Invalid region {region!r}')
185-
error_msg += _get_candidate_str(region, df['Region'].unique())
192+
candidate_strs = _get_candidate_str(region,
193+
df['Region'].unique())
194+
if not candidate_strs:
195+
error_msg += _get_all_supported_regions_str()
196+
raise ValueError(error_msg)
197+
error_msg += candidate_strs
186198
raise ValueError(error_msg)
187199

188200
if zone is not None:
@@ -321,6 +333,7 @@ def list_accelerators_impl(
321333
df: pd.DataFrame,
322334
gpus_only: bool,
323335
name_filter: Optional[str],
336+
region_filter: Optional[str],
324337
case_sensitive: bool = True,
325338
) -> Dict[str, List[InstanceTypeInfo]]:
326339
"""Lists accelerators offered in a cloud service catalog.
@@ -341,11 +354,16 @@ def list_accelerators_impl(
341354
'MemoryGiB',
342355
'Price',
343356
'SpotPrice',
357+
'Region',
344358
]].dropna(subset=['AcceleratorName']).drop_duplicates()
345359
if name_filter is not None:
346360
df = df[df['AcceleratorName'].str.contains(name_filter,
347361
case=case_sensitive,
348362
regex=True)]
363+
if region_filter is not None:
364+
df = df[df['Region'].str.contains(region_filter,
365+
case=case_sensitive,
366+
regex=True)]
349367
df['AcceleratorCount'] = df['AcceleratorCount'].astype(int)
350368
grouped = df.groupby('AcceleratorName')
351369

@@ -366,6 +384,7 @@ def make_list_from_df(rows):
366384
row['MemoryGiB'],
367385
row['Price'],
368386
row['SpotPrice'],
387+
row['Region'],
369388
),
370389
axis='columns',
371390
).tolist()

sky/clouds/service_catalog/gcp_catalog.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def get_instance_type_for_accelerator(
210210
def validate_region_zone(
211211
region: Optional[str],
212212
zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
213-
return common.validate_region_zone_impl(_df, region, zone)
213+
return common.validate_region_zone_impl('gcp', _df, region, zone)
214214

215215

216216
def accelerator_in_region_or_zone(acc_name: str,
@@ -273,11 +273,12 @@ def get_accelerator_hourly_cost(accelerator: str,
273273
def list_accelerators(
274274
gpus_only: bool,
275275
name_filter: Optional[str] = None,
276+
region_filter: Optional[str] = None,
276277
case_sensitive: bool = True,
277278
) -> Dict[str, List[common.InstanceTypeInfo]]:
278279
"""Returns all instance types in GCP offering GPUs."""
279280
results = common.list_accelerators_impl('GCP', _df, gpus_only, name_filter,
280-
case_sensitive)
281+
region_filter, case_sensitive)
281282

282283
a100_infos = results.get('A100', []) + results.get('A100-80GB', [])
283284
if not a100_infos:

tests/test_list_accelerators.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ def test_list_ccelerators_all():
1717
assert 'A100-80GB' in result, result
1818

1919

20-
def test_list_accelerators_filters():
20+
def test_list_accelerators_name_filter():
2121
result = sky.list_accelerators(gpus_only=False, name_filter='V100')
2222
assert sorted(result.keys()) == ['V100', 'V100-32GB'], result
23+
24+
25+
def test_list_accelerators_region_filter():
26+
result = sky.list_accelerators(gpus_only=False,
27+
clouds="aws",
28+
region_filter='us-west-1')
29+
all_regions = []
30+
for res in result.values():
31+
for instance in res:
32+
all_regions.append(instance.region)
33+
assert all([region == 'us-west-1' for region in all_regions])

0 commit comments

Comments
 (0)