Skip to content

Commit 034e38b

Browse files
authored
Make GPU/TPU names case-insensitive (skypilot-org#463)
1 parent 0e49ef2 commit 034e38b

File tree

7 files changed

+43
-16
lines changed

7 files changed

+43
-16
lines changed

examples/time_estimators.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _v100(num_v100s):
4949
return _v100(num_v100s)
5050

5151
elif isinstance(resources.cloud, sky.GCP):
52-
accelerators = resources.get_accelerators()
52+
accelerators = resources.accelerators
5353
if accelerators is None:
5454
assert False, 'not supported'
5555

@@ -131,8 +131,8 @@ def resnet50_infer_estimate_runtime(resources):
131131
# TODO: this ignores offline vs. online. It's a huge batch.
132132
estimated_run_time_seconds = \
133133
flops_for_one_image * num_images / utilized_flops
134-
elif resources.get_accelerators() is not None:
135-
accs = resources.get_accelerators()
134+
elif resources.accelerators is not None:
135+
accs = resources.accelerators
136136
for acc, acc_count in accs.items():
137137
break
138138
assert acc == 'T4' and acc_count == 1, resources

sky/backends/cloud_vm_ray_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _get_task_demands_dict(task: Task) -> Optional[Tuple[Optional[str], int]]:
8888
assert len(task.resources) == 1, task.resources
8989
resources = list(task.resources)[0]
9090
if resources is not None:
91-
accelerator_dict = resources.get_accelerators()
91+
accelerator_dict = resources.accelerators
9292
return accelerator_dict
9393

9494

sky/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _default_interactive_node_name(node_type: str):
218218

219219
def _infer_interactive_node_type(resources: sky.Resources):
220220
"""Determine interactive node type from resources."""
221-
accelerators = resources.get_accelerators()
221+
accelerators = resources.accelerators
222222
cloud = resources.cloud
223223
if accelerators:
224224
# We only support homogenous accelerators for now.

sky/clouds/aws.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _make(instance_type):
190190
return [r]
191191

192192
# Currently, handle a filter on accelerators only.
193-
accelerators = resources.get_accelerators()
193+
accelerators = resources.accelerators
194194
if accelerators is None:
195195
# No requirements to filter, so just return a default VM type.
196196
return _make(AWS.get_default_instance_type())

sky/clouds/azure.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def _make(instance_type):
182182
return [r]
183183

184184
# Currently, handle a filter on accelerators only.
185-
accelerators = resources.get_accelerators()
185+
accelerators = resources.accelerators
186186
if accelerators is None:
187187
# No requirements to filter, so just return a default VM type.
188188
return _make(Azure.get_default_instance_type())

sky/clouds/gcp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def make_deploy_resources_variables(self, resources):
177177
'custom_resources': None,
178178
'use_spot': r.use_spot,
179179
}
180-
accelerators = r.get_accelerators()
180+
accelerators = r.accelerators
181181
if accelerators is not None:
182182
assert len(accelerators) == 1, r
183183
acc, acc_count = list(accelerators.items())[0]

sky/resources.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,23 @@
33

44
from sky import clouds
55
from sky import sky_logging
6+
from sky.clouds import service_catalog
67

78
logger = sky_logging.init_logger(__name__)
89

910
DEFAULT_DISK_SIZE = 256
1011

1112

13+
def _get_name_from_catalog(accelerator: str) -> str:
14+
"""Returns the matched accelerator name in the catalog."""
15+
acc_names = list(service_catalog.list_accelerators(gpus_only=False).keys())
16+
try:
17+
index = [n.casefold() for n in acc_names].index(accelerator.casefold())
18+
except ValueError:
19+
raise ValueError(f'Invalid accelerator name: {accelerator}') from None
20+
return acc_names[index]
21+
22+
1223
class Resources:
1324
"""A cloud resource bundle.
1425
@@ -69,7 +80,7 @@ def __init__(
6980
assert len(accelerators) == 1, accelerators
7081

7182
acc, _ = list(accelerators.items())[0]
72-
if 'tpu' in acc:
83+
if 'tpu' in acc.lower():
7384
if cloud is None:
7485
cloud = clouds.GCP()
7586
assert cloud.is_same_cloud(clouds.GCP()), 'Cloud must be GCP.'
@@ -80,7 +91,7 @@ def __init__(
8091
' default (2.5.0)')
8192
accelerator_args['tf_version'] = '2.5.0'
8293

83-
self.accelerators = accelerators
94+
self._accelerators = self._rename_accelerators(accelerators)
8495
self.accelerator_args = accelerator_args
8596

8697
self._use_spot_specified = use_spot is not None
@@ -133,20 +144,36 @@ def _try_validate_accelerators(self) -> None:
133144
# because e.g., the instance may have 4 GPUs, while the task
134145
# specifies to use 1 GPU.
135146

136-
def get_accelerators(self) -> Optional[Dict[str, int]]:
147+
def _rename_accelerators(
148+
self,
149+
accelerators: Union[None, Dict[str, int]],
150+
) -> Optional[Dict[str, int]]:
151+
"""Renames the accelerators in a case-sensitive manner."""
152+
if accelerators is not None:
153+
return {_get_name_from_catalog(name): cnt \
154+
for name, cnt in accelerators.items()}
155+
else:
156+
return None
157+
158+
@property
159+
def accelerators(self) -> Optional[Dict[str, int]]:
137160
"""Returns the accelerators field directly or by inferring.
138161
139162
For example, Resources(AWS, 'p3.2xlarge') has its accelerators field
140163
set to None, but this function will infer {'V100': 1} from the instance
141164
type.
142165
"""
143-
if self.accelerators is not None:
144-
return self.accelerators
166+
if self._accelerators is not None:
167+
return self._accelerators
145168
if self.cloud is not None and self.instance_type is not None:
146169
return self.cloud.get_accelerators_from_instance_type(
147170
self.instance_type)
148171
return None
149172

173+
@accelerators.setter
174+
def accelerators(self, accelerators: Union[None, Dict[str, int]]) -> None:
175+
self._accelerators = self._rename_accelerators(accelerators)
176+
150177
def get_cost(self, seconds: float):
151178
"""Returns cost in USD for the runtime in seconds."""
152179
hours = seconds / 3600
@@ -177,8 +204,8 @@ def is_same_resources(self, other: 'Resources') -> bool:
177204
return False
178205
# self.instance_type == other.instance_type
179206

180-
other_accelerators = other.get_accelerators()
181-
accelerators = self.get_accelerators()
207+
other_accelerators = other.accelerators
208+
accelerators = self.accelerators
182209
if accelerators != other_accelerators:
183210
return False
184211
# self.accelerators == other.accelerators
@@ -204,7 +231,7 @@ def less_demanding_than(self, other: 'Resources') -> bool:
204231
return False
205232
# self.instance_type <= other.instance_type
206233

207-
other_accelerators = other.get_accelerators()
234+
other_accelerators = other.accelerators
208235
if self.accelerators is not None and other_accelerators is None:
209236
return False
210237

0 commit comments

Comments
 (0)