Skip to content

Commit bc1bdec

Browse files
authored
[core][distributed] exact ray placement control (#12732)
Signed-off-by: youkaichao <[email protected]>
1 parent 022bcc7 commit bc1bdec

File tree

6 files changed

+173
-13
lines changed

6 files changed

+173
-13
lines changed

Diff for: .buildkite/test-pipeline.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ steps:
128128
- tests/spec_decode/e2e/test_integration_dist_tp4
129129
- tests/compile
130130
- examples/offline_inference/rlhf.py
131+
- examples/offline_inference/ray_placement.py
131132
commands:
132133
- pytest -v -s distributed/test_utils.py
133134
- pytest -v -s compile/test_basic_correctness.py
@@ -136,6 +137,7 @@ steps:
136137
# TODO: create a dedicated test section for multi-GPU example tests
137138
# when we have multiple distributed example tests
138139
- python3 ../examples/offline_inference/rlhf.py
140+
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
139141

140142
- label: Metrics, Tracing Test # 10min
141143
num_gpus: 2

Diff for: examples/offline_inference/ray_placement.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
a simple demonstration to show how to control
4+
the placement of the vLLM workers with Ray.
5+
The key is to set VLLM_RAY_PER_WORKER_GPUS and
6+
VLLM_RAY_BUNDLE_INDICES properly.
7+
"""
8+
import os
9+
10+
import ray
11+
from ray.util.placement_group import placement_group
12+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
13+
14+
from vllm import LLM
15+
from vllm.worker.worker import Worker
16+
17+
18+
class MyWorker(Worker):
19+
20+
def report_device_id(self) -> str:
21+
from vllm.platforms import current_platform
22+
return current_platform.get_device_uuid(self.device.index)
23+
24+
25+
class MyLLM(LLM):
26+
27+
def __init__(self, *args, bundle_indices: list, **kwargs):
28+
# a hack to make the script work.
29+
# stop ray from manipulating CUDA_VISIBLE_DEVICES
30+
# at the top-level
31+
del os.environ["CUDA_VISIBLE_DEVICES"]
32+
# every worker will use 0.4 GPU, so that we can schedule
33+
# 2 instances on the same GPUs.
34+
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
35+
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(
36+
map(str, bundle_indices))
37+
print(f"creating LLM with bundle_indices={bundle_indices}")
38+
super().__init__(*args, **kwargs)
39+
40+
41+
class RayTrainingActor:
42+
43+
def report_device_id(self) -> str:
44+
# the argument for get_device_uuid is the index
45+
# of the GPU in the visible devices.
46+
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
47+
from vllm.platforms import current_platform
48+
return current_platform.get_device_uuid(0)
49+
50+
51+
# ray manages 4 GPUs
52+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
53+
ray.init()
54+
55+
# we want to co-locate vLLM instance and the training actor
56+
# on the same set of GPUs.
57+
# the placement plan is as follows:
58+
# GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2)
59+
# GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2)
60+
61+
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
62+
ray.get(pg.ready())
63+
print(f"placement group has bundles {pg.bundle_specs=}")
64+
65+
training_actors = []
66+
training_actor_device_ids = []
67+
inference_engines = []
68+
inference_engine_device_ids = []
69+
70+
for bundle_index in [0, 1, 2, 3]:
71+
training_actor = ray.remote(
72+
num_cpus=0,
73+
num_gpus=0.4,
74+
scheduling_strategy=PlacementGroupSchedulingStrategy(
75+
placement_group=pg,
76+
placement_group_capture_child_tasks=True,
77+
placement_group_bundle_index=bundle_index,
78+
),
79+
)(RayTrainingActor).remote()
80+
training_actors.append(training_actor)
81+
device_id = ray.get(training_actor.report_device_id.remote())
82+
print(f"training actor {bundle_index} is on {device_id}")
83+
training_actor_device_ids.append(device_id)
84+
85+
for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
86+
# IMPORTANT: when creating vLLM instances, we need to
87+
# make sure there are no GPU activities on the target GPUs,
88+
# otherwise, they will interfere with the vLLM memory profiling,
89+
# and cause unexpected behaviors.
90+
llm = ray.remote(
91+
num_cpus=0,
92+
num_gpus=0,
93+
scheduling_strategy=PlacementGroupSchedulingStrategy(
94+
placement_group=pg,
95+
placement_group_capture_child_tasks=True,
96+
),
97+
)(MyLLM).remote(
98+
model="facebook/opt-125m",
99+
enforce_eager=True,
100+
worker_cls=MyWorker,
101+
tensor_parallel_size=2,
102+
distributed_executor_backend="ray",
103+
gpu_memory_utilization=0.4,
104+
bundle_indices=bundle_indices,
105+
)
106+
inference_engines.append(llm)
107+
# don't call any method on the inference engine here,
108+
# otherwise it will block until the vLLM instance is created.
109+
110+
for i, llm in enumerate(inference_engines):
111+
inference_engine_device_ids.append(
112+
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())))
113+
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
114+
115+
# check the placement
116+
# the first two training actors should be
117+
# on the same GPUs as the first inference engine
118+
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
119+
# the last two training actors should be
120+
# on the same GPUs as the second inference engine
121+
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

Diff for: vllm/envs.py

+14
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
8686
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
8787
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
88+
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
89+
VLLM_RAY_BUNDLE_INDICES: str = ""
8890

8991

9092
def get_default_cache_root():
@@ -550,6 +552,18 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
550552
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
551553
),
552554

555+
# Number of GPUs per worker in Ray, if it is set to be a fraction,
556+
# it allows ray to schedule multiple actors on a single GPU,
557+
# so that users can colocate other actors on the same GPUs as vLLM.
558+
"VLLM_RAY_PER_WORKER_GPUS":
559+
lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")),
560+
561+
# Bundle indices for Ray, if it is set, it can control precisely
562+
# which indices are used for the Ray bundle, for every worker.
563+
# Format: comma-separated list of integers, e.g. "0,1,2,3"
564+
"VLLM_RAY_BUNDLE_INDICES":
565+
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
566+
553567
# When on a Nvidia GPU aligns single entries (within a page) so they are 256
554568
# byte aligned for better performance, this increases the memory usage of
555569
# the cache. Currently this only affects MLA that results in non-256

Diff for: vllm/executor/ray_distributed_executor.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,7 @@ def _get_env_vars_to_be_updated(self):
129129

130130
def _init_workers_ray(self, placement_group: "PlacementGroup",
131131
**ray_remote_kwargs):
132-
if (self.parallel_config.tensor_parallel_size == 1
133-
and self.parallel_config.pipeline_parallel_size == 1):
134-
# For single GPU case, we use a ray worker with constrained memory.
135-
num_gpus = self.cache_config.gpu_memory_utilization
136-
else:
137-
# Otherwise, the ray workers are allocated with a full GPU.
138-
num_gpus = 1
132+
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
139133

140134
# The driver dummy worker does not actually use any resources.
141135
# It holds the resource for the driver worker.
@@ -155,12 +149,29 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
155149
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
156150

157151
# Create the workers.
158-
driver_ip = get_ip()
159-
rank = 0
152+
bundle_indices: List[int]
153+
if envs.VLLM_RAY_BUNDLE_INDICES:
154+
# Use the bundle indices specified by the user.
155+
bundle_indices = list(
156+
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
157+
assert len(bundle_indices) == self.parallel_config.world_size, \
158+
("VLLM_RAY_BUNDLE_INDICES must have the same size"
159+
f" as the world size, but got {bundle_indices=} "
160+
f"and {self.parallel_config.world_size=}")
161+
assert len(set(bundle_indices)) == len(bundle_indices), \
162+
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
163+
f" but got {bundle_indices=}")
164+
else:
165+
# use the first N bundles that have GPU resources.
166+
bundle_indices = []
167+
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
168+
if bundle.get(current_platform.ray_device_key, 0):
169+
bundle_indices.append(bundle_id)
170+
bundle_indices = bundle_indices[:self.parallel_config.world_size]
171+
160172
worker_metadata: List[RayWorkerMetaData] = []
161-
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
162-
if not bundle.get(current_platform.ray_device_key, 0):
163-
continue
173+
driver_ip = get_ip()
174+
for rank, bundle_id in enumerate(bundle_indices):
164175
scheduling_strategy = PlacementGroupSchedulingStrategy(
165176
placement_group=placement_group,
166177
placement_group_capture_child_tasks=True,
@@ -187,7 +198,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
187198
rpc_rank=rank)
188199
worker_metadata.append(
189200
RayWorkerMetaData(worker=worker, created_rank=rank))
190-
rank += 1
191201

192202
worker_ips = ray.get([
193203
each.worker.get_node_ip.remote() # type: ignore[attr-defined]

Diff for: vllm/platforms/cuda.py

+8
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,14 @@ def get_device_name(cls, device_id: int = 0) -> str:
275275
physical_device_id = device_id_to_physical_device_id(device_id)
276276
return cls._get_physical_device_name(physical_device_id)
277277

278+
@classmethod
279+
@lru_cache(maxsize=8)
280+
@with_nvml_context
281+
def get_device_uuid(cls, device_id: int = 0) -> str:
282+
physical_device_id = device_id_to_physical_device_id(device_id)
283+
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
284+
return pynvml.nvmlDeviceGetUUID(handle)
285+
278286
@classmethod
279287
@lru_cache(maxsize=8)
280288
@with_nvml_context

Diff for: vllm/platforms/interface.py

+5
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ def get_device_name(cls, device_id: int = 0) -> str:
183183
"""Get the name of a device."""
184184
raise NotImplementedError
185185

186+
@classmethod
187+
def get_device_uuid(cls, device_id: int = 0) -> str:
188+
"""Get the uuid of a device, e.g. the PCI bus ID."""
189+
raise NotImplementedError
190+
186191
@classmethod
187192
def get_device_total_memory(cls, device_id: int = 0) -> int:
188193
"""Get the total memory of a device in bytes."""

0 commit comments

Comments
 (0)