
Commit fca1824

[UX] Add environment variable SKY_NUM_GPUS_PER_NODE (skypilot-org#1337)
* add SKY_NUM_GPUS_PER_NODE
* increase multi-node progress timeout
* pin torch version
* add comment
* address comment
* fix smoke test
* address comments
1 parent c2b2341 commit fca1824

File tree

- docs/source/examples/distributed-jobs.rst
- examples/resnet_distributed_torch.yaml
- sky/backends/cloud_vm_ray_backend.py
- tests/test_smoke.py

4 files changed: +27 −8

docs/source/examples/distributed-jobs.rst

Lines changed: 8 additions & 5 deletions
@@ -13,14 +13,15 @@ For example, here is a simple PyTorch Distributed training example:
 name: resnet-distributed-app

 resources:
-  accelerators: V100
+  accelerators: V100:4

 num_nodes: 2

 setup: |
   pip3 install --upgrade pip
   git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet
-  cd pytorch-distributed-resnet && pip3 install -r requirements.txt
+  # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
+  cd pytorch-distributed-resnet && pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
   mkdir -p data && mkdir -p saved_models && cd data && \
   wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
   tar -xvzf cifar-10-python.tar.gz
@@ -30,20 +31,22 @@ For example, here is a simple PyTorch Distributed training example:
   num_nodes=`echo "$SKY_NODE_IPS" | wc -l`
   master_addr=`echo "$SKY_NODE_IPS" | head -n1`
-  python3 -m torch.distributed.launch --nproc_per_node=1 \
+  python3 -m torch.distributed.launch --nproc_per_node=$SKY_NUM_GPUS_PER_NODE \
   --nnodes=$num_nodes --node_rank=${SKY_NODE_RANK} --master_addr=$master_addr \
   --master_port=8008 resnet_ddp.py --num_epochs 20

 In the above, :code:`num_nodes: 2` specifies that this task is to be run on 2
 nodes. The :code:`setup` and :code:`run` commands are executed on both nodes.

-SkyPilot exposes two environment variables to distinguish per-node commands:
+SkyPilot exposes these environment variables that can be accessed in a task's ``run`` commands:

 - :code:`SKY_NODE_RANK`: rank (an integer ID from 0 to :code:`num_nodes-1`) of
-  the node executing the task
+  the node executing the task.
 - :code:`SKY_NODE_IPS`: a string of IP addresses of the nodes reserved to execute
   the task, where each line contains one IP address.

   You can retrieve the number of nodes by :code:`echo "$SKY_NODE_IPS" | wc -l`
   and the IP address of the third node by :code:`echo "$SKY_NODE_IPS" | sed -n
   3p`.
+- :code:`SKY_NUM_GPUS_PER_NODE`: number of GPUs reserved on each node to execute the
+  task; the same as the count in ``accelerators: <name>:<count>`` (rounded up if a fraction).
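
As a quick illustration of these variables, here is a minimal sketch of how a training script might consume them (the variable names come from the docs above; the script itself is hypothetical):

    import os

    # SKY_NODE_IPS holds one IP address per line; the other two are integers.
    rank = int(os.environ['SKY_NODE_RANK'])
    node_ips = os.environ['SKY_NODE_IPS'].splitlines()
    gpus_per_node = int(os.environ['SKY_NUM_GPUS_PER_NODE'])

    # A common use: derive the global world size for distributed training.
    world_size = len(node_ips) * gpus_per_node
    print(f'node {rank}/{len(node_ips)}, {gpus_per_node} GPU(s), '
          f'world size {world_size}')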

examples/resnet_distributed_torch.yaml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ num_nodes: 2
 setup: |
   pip3 install --upgrade pip
   git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet
-  cd pytorch-distributed-resnet && pip3 install -r requirements.txt
+  cd pytorch-distributed-resnet && pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
   mkdir -p data && mkdir -p saved_models && cd data && \
   wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
   tar -xvzf cifar-10-python.tar.gz

sky/backends/cloud_vm_ray_backend.py

Lines changed: 6 additions & 2 deletions
@@ -3,6 +3,7 @@
 import enum
 import getpass
 import inspect
+import math
 import json
 import os
 import pathlib
@@ -62,7 +63,7 @@

 # Timeout (seconds) for provision progress: if in this duration no new nodes
 # are launched, abort and failover.
-_NODES_LAUNCHING_PROGRESS_TIMEOUT = 60
+_NODES_LAUNCHING_PROGRESS_TIMEOUT = 90

 # Time gap between retries after failing to provision in all possible places.
 # Used only if --retry-until-up is set.
@@ -339,17 +340,19 @@ def add_ray_task(self,
             cpu_str = f', num_cpus={backend_utils.DEFAULT_TASK_CPU_DEMAND}'

         resources_str = ''
+        num_gpus = 0
         num_gpus_str = ''
         if ray_resources_dict is not None:
             assert len(ray_resources_dict) == 1, \
                 ('There can only be one type of accelerator per instance.'
                  f' Found: {ray_resources_dict}.')
+            num_gpus = list(ray_resources_dict.values())[0]
             resources_str = f', resources={json.dumps(ray_resources_dict)}'

             # Passing this ensures that the Ray remote task gets
             # CUDA_VISIBLE_DEVICES set correctly. If not passed, that flag
             # would be force-set to empty by Ray.
-            num_gpus_str = f', num_gpus={list(ray_resources_dict.values())[0]}'
+            num_gpus_str = f', num_gpus={num_gpus}'
             # `num_gpus` should be empty when the accelerator is not GPU.
             # FIXME: use a set of GPU types.
             resources_key = list(ray_resources_dict.keys())[0]
@@ -376,6 +379,7 @@ def add_ray_task(self,
         log_path = os.path.expanduser({log_path!r})

         if script is not None:
+            sky_env_vars_dict['SKY_NUM_GPUS_PER_NODE'] = {int(math.ceil(num_gpus))!r}
             ip = gang_scheduling_id_to_ip[{gang_scheduling_id!r}]
             sky_env_vars_dict['SKY_NODE_RANK'] = ip_rank_map[ip]
             sky_env_vars_dict['SKY_JOB_ID'] = {self.job_id}
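
For clarity, a standalone sketch of the rounding idiom added above (the fractional request value is made up; `!r` renders the computed integer into the generated driver script):

    import math

    # A fractional share such as `--gpus K80:0.2` still exposes one whole
    # device, so the generated code ceils the count before exporting it.
    num_gpus = 0.2  # hypothetical fractional GPU request
    line = ("sky_env_vars_dict['SKY_NUM_GPUS_PER_NODE'] = "
            f"{int(math.ceil(num_gpus))!r}")
    print(line)  # sky_env_vars_dict['SKY_NUM_GPUS_PER_NODE'] = 1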

tests/test_smoke.py

Lines changed: 12 additions & 0 deletions
@@ -294,6 +294,11 @@ def test_job_queue():
             f'sky cancel {name} 2',
             'sleep 5',
             f'sky queue {name} | grep {name}-3 | grep RUNNING',
+            f'sky cancel {name} 3',
+            f'sky exec {name} --gpus K80:0.2 "[[ \$SKY_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+            f'sky exec {name} --gpus K80:1 "[[ \$SKY_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+            f'sky logs {name} 4 --status',
+            f'sky logs {name} 5 --status',
         ],
         f'sky down -y {name}',
     )
@@ -315,6 +320,13 @@ def test_n_node_job_queue():
             f'sky cancel {name} 1',
             'sleep 5',
             f'sky queue {name} | grep {name}-3 | grep RUNNING',
+            f'sky cancel {name} 3',
+            f'sky exec {name} --gpus K80:0.2 "[[ \$SKY_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+            f'sky exec {name} --gpus K80:0.2 --num-nodes 2 "[[ \$SKY_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+            f'sky exec {name} --gpus K80:1 --num-nodes 2 "[[ \$SKY_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1"',
+            f'sky logs {name} 4 --status',
+            f'sky logs {name} 5 --status',
+            f'sky logs {name} 6 --status',
         ],
         f'sky down -y {name}',
     )
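
The escaped shell assertion in these commands, written out in Python for clarity (a sketch; the smoke tests run the bash form remotely via `sky exec`):

    import os
    import sys

    # Equivalent to: [[ $SKY_NUM_GPUS_PER_NODE -eq 1 ]] || exit 1
    if int(os.environ.get('SKY_NUM_GPUS_PER_NODE', '0')) != 1:
        sys.exit(1)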
