3
3
4
4
from sky import clouds
5
5
from sky import sky_logging
6
+ from sky .clouds import service_catalog
6
7
7
8
# Module-level logger, obtained from sky_logging.init_logger and keyed by
# this module's __name__.
logger = sky_logging .init_logger (__name__ )
8
9
9
10
# Default disk size. NOTE(review): the unit is not visible here (presumably
# GB) -- confirm against the code that consumes this constant.
DEFAULT_DISK_SIZE = 256
10
11
11
12
13
def _get_name_from_catalog(accelerator: str) -> str:
    """Look up the catalog's own spelling of an accelerator name.

    The requested name is compared, ignoring case (``str.casefold``),
    against every key returned by
    ``service_catalog.list_accelerators(gpus_only=False)``. The first key
    that matches is returned exactly as the catalog spells it.

    Args:
        accelerator: Accelerator name in any casing.

    Returns:
        The matching accelerator name as it appears in the catalog.

    Raises:
        ValueError: If no catalog entry matches the given name.
    """
    catalog_names = service_catalog.list_accelerators(gpus_only=False).keys()
    wanted = accelerator.casefold()
    for catalog_name in catalog_names:
        # First match wins, mirroring a list.index() lookup.
        if catalog_name.casefold() == wanted:
            return catalog_name
    raise ValueError(f'Invalid accelerator name: { accelerator } ')
21
+
22
+
12
23
class Resources :
13
24
"""A cloud resource bundle.
14
25
@@ -69,7 +80,7 @@ def __init__(
69
80
assert len (accelerators ) == 1 , accelerators
70
81
71
82
acc , _ = list (accelerators .items ())[0 ]
72
- if 'tpu' in acc :
83
+ if 'tpu' in acc . lower () :
73
84
if cloud is None :
74
85
cloud = clouds .GCP ()
75
86
assert cloud .is_same_cloud (clouds .GCP ()), 'Cloud must be GCP.'
@@ -80,7 +91,7 @@ def __init__(
80
91
' default (2.5.0)' )
81
92
accelerator_args ['tf_version' ] = '2.5.0'
82
93
83
- self .accelerators = accelerators
94
+ self ._accelerators = self . _rename_accelerators ( accelerators )
84
95
self .accelerator_args = accelerator_args
85
96
86
97
self ._use_spot_specified = use_spot is not None
@@ -133,20 +144,36 @@ def _try_validate_accelerators(self) -> None:
133
144
# because e.g., the instance may have 4 GPUs, while the task
134
145
# specifies to use 1 GPU.
135
146
136
- def get_accelerators (self ) -> Optional [Dict [str , int ]]:
147
+ def _rename_accelerators (
148
+ self ,
149
+ accelerators : Union [None , Dict [str , int ]],
150
+ ) -> Optional [Dict [str , int ]]:
151
+ """Renames the accelerators in a case-sensitive manner."""
152
+ if accelerators is not None :
153
+ return {_get_name_from_catalog (name ): cnt \
154
+ for name , cnt in accelerators .items ()}
155
+ else :
156
+ return None
157
+
158
+ @property
159
+ def accelerators (self ) -> Optional [Dict [str , int ]]:
137
160
"""Returns the accelerators field directly or by inferring.
138
161
139
162
For example, Resources(AWS, 'p3.2xlarge') has its accelerators field
140
163
set to None, but this function will infer {'V100': 1} from the instance
141
164
type.
142
165
"""
143
- if self .accelerators is not None :
144
- return self .accelerators
166
+ if self ._accelerators is not None :
167
+ return self ._accelerators
145
168
if self .cloud is not None and self .instance_type is not None :
146
169
return self .cloud .get_accelerators_from_instance_type (
147
170
self .instance_type )
148
171
return None
149
172
173
+ @accelerators .setter
174
+ def accelerators (self , accelerators : Union [None , Dict [str , int ]]) -> None :
175
+ self ._accelerators = self ._rename_accelerators (accelerators )
176
+
150
177
def get_cost (self , seconds : float ):
151
178
"""Returns cost in USD for the runtime in seconds."""
152
179
hours = seconds / 3600
@@ -177,8 +204,8 @@ def is_same_resources(self, other: 'Resources') -> bool:
177
204
return False
178
205
# self.instance_type == other.instance_type
179
206
180
- other_accelerators = other .get_accelerators ()
181
- accelerators = self .get_accelerators ()
207
+ other_accelerators = other .accelerators
208
+ accelerators = self .accelerators
182
209
if accelerators != other_accelerators :
183
210
return False
184
211
# self.accelerators == other.accelerators
@@ -204,7 +231,7 @@ def less_demanding_than(self, other: 'Resources') -> bool:
204
231
return False
205
232
# self.instance_type <= other.instance_type
206
233
207
- other_accelerators = other .get_accelerators ()
234
+ other_accelerators = other .accelerators
208
235
if self .accelerators is not None and other_accelerators is None :
209
236
return False
210
237
0 commit comments