Skip to content

Commit c52e9d4

Browse files
authored
Fix bugs in GCP A100 prices (skypilot-org#1368)
* Fix GCP A100 price bugs * yapf
1 parent 87b92af commit c52e9d4

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

sky/clouds/service_catalog/data_fetchers/fetch_gcp.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
}
4242

4343
# FIXME(woosuk): This URL can change.
44+
A2_PRICING_URL = '/compute/vm-instance-pricing_34568c2efd1858a89d6f5b0f1cdd171bbea1cdcba646e9771e6ef4028238086f.frame' # pylint: disable=line-too-long
4445
A2_INSTANCE_TYPES = {
4546
'a2-highgpu-1g': {
4647
'vCPUs': 12,
@@ -264,9 +265,6 @@ def parse_price(price_str):
264265
else:
265266
# Others (e.g., per vCPU hour or per GB hour pricing rule table).
266267
df = df[['Item', 'Region', 'Price', 'SpotPrice']]
267-
item = df['Item'].iloc[0]
268-
if item == 'Predefined vCPUs':
269-
df = get_a2_df(df)
270268
return df
271269

272270

@@ -302,12 +300,13 @@ def parse_machine_type_list(list_str):
302300
return df
303301

304302

305-
def get_a2_df(a2_pricing_df):
306-
cpu_pricing = a2_pricing_df[a2_pricing_df['Item'] == 'Predefined vCPUs']
307-
memory_pricing = a2_pricing_df[a2_pricing_df['Item'] == 'Predefined Memory']
303+
def get_a2_df():
304+
a2_pricing = get_vm_price_table(GCP_URL + A2_PRICING_URL)
305+
cpu_pricing = a2_pricing[a2_pricing['Item'] == 'Predefined vCPUs']
306+
memory_pricing = a2_pricing[a2_pricing['Item'] == 'Predefined Memory']
308307

309308
table = []
310-
for region in a2_pricing_df['Region'].unique():
309+
for region in a2_pricing['Region'].unique():
311310
per_cpu_price = cpu_pricing[cpu_pricing['Region'] ==
312311
region]['Price'].values[0]
313312
per_cpu_spot_price = cpu_pricing[cpu_pricing['Region'] ==
@@ -351,7 +350,9 @@ def get_vm_df():
351350
df for df in vm_dfs if df is not None and 'InstanceType' in df.columns
352351
]
353352

354-
vm_df = pd.concat(vm_dfs)
353+
# Handle A2 instance types separately.
354+
a2_df = get_a2_df()
355+
vm_df = pd.concat(vm_dfs + [a2_df])
355356

356357
vm_zones = get_vm_zones(GCP_VM_ZONES_URL)
357358
# Remove regions not in the pricing data.

0 commit comments

Comments
 (0)