Skip to content

Commit fd6c335

Browse files
[Onprem] Support for Different Type of GPUs + Small Bugfix (skypilot-org#1356)
* Ok
* Great suggestion from Zhanghao
* fix
1 parent 75ab3de commit fd6c335

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

sky/backends/backend_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,6 @@ def fill_template(template_name: str,
159159
output_prefix)
160160
output_path = os.path.abspath(output_path)
161161

162-
# Add yaml file path to the template variables.
163-
variables['sky_ray_yaml_remote_path'] = SKY_RAY_YAML_REMOTE_PATH
164-
variables['sky_ray_yaml_local_path'] = output_path
165162
# Write out yaml config.
166163
template = jinja2.Template(template)
167164
content = template.render(**variables)
@@ -786,6 +783,11 @@ def write_cluster_config(
786783
# Sky remote utils.
787784
'sky_remote_path': SKY_REMOTE_PATH,
788785
'sky_local_path': str(local_wheel_path),
786+
# Add yaml file path to the template variables.
787+
'sky_ray_yaml_remote_path': SKY_RAY_YAML_REMOTE_PATH,
788+
'sky_ray_yaml_local_path':
789+
tmp_yaml_path
790+
if not isinstance(cloud, clouds.Local) else yaml_path,
789791
'sky_version': str(version.parse(sky.__version__)),
790792
'sky_wheel_hash': wheel_hash,
791793
# Local IP handling (optional).

sky/backends/onprem_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,11 @@ def get_local_cluster_accelerators(
266266
'T4',
267267
'P4',
268268
'K80',
269-
'A100',]
269+
'A100',
270+
'1080',
271+
'2080',
272+
'A5000',
273+
'A6000']
270274
accelerators_dict = {}
271275
for acc in all_accelerators:
272276
output_str = os.popen(f'lspci | grep \'{acc}\'').read()
@@ -358,9 +362,10 @@ def _stop_ray_workers(runner: command_runner.SSHCommandRunner):
358362

359363
# Launching Ray on the head node.
360364
head_resources = json.dumps(custom_resources[0], separators=(',', ':'))
365+
head_gpu_count = sum(list(custom_resources[0].values()))
361366
head_cmd = ('ray start --head --port=6379 '
362367
'--object-manager-port=8076 --dashboard-port 8265 '
363-
f'--resources={head_resources!r}')
368+
f'--resources={head_resources!r} --num-gpus={head_gpu_count}')
364369

365370
with console.status('[bold cyan]Launching ray cluster on head'):
366371
backend_utils.run_command_and_handle_ssh_failure(
@@ -399,9 +404,11 @@ def _start_ray_workers(
399404

400405
worker_resources = json.dumps(custom_resources[idx + 1],
401406
separators=(',', ':'))
407+
worker_gpu_count = sum(list(custom_resources[idx + 1].values()))
402408
worker_cmd = (f'ray start --address={head_ip}:6379 '
403409
'--object-manager-port=8076 --dashboard-port 8265 '
404-
f'--resources={worker_resources!r}')
410+
f'--resources={worker_resources!r} '
411+
f'--num-gpus={worker_gpu_count}')
405412
backend_utils.run_command_and_handle_ssh_failure(
406413
runner,
407414
worker_cmd,

sky/resources.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,12 @@ def _set_accelerators(
208208
except ValueError:
209209
with ux_utils.print_exception_no_traceback():
210210
raise ValueError(parse_error) from None
211-
assert len(accelerators) == 1, accelerators
211+
212+
# Ignore check for the local cloud case.
213+
# It is possible the accelerators dict can contain multiple
214+
# types of accelerators for some on-prem clusters.
215+
if not isinstance(self._cloud, clouds.Local):
216+
assert len(accelerators) == 1, accelerators
212217

213218
# Canonicalize the accelerator names.
214219
accelerators = {

0 commit comments

Comments
 (0)