# SPDX-License-Identifier: Apache-2.0
"""
A simple demonstration of how to control the placement
of vLLM workers with Ray.
The key is to set VLLM_RAY_PER_WORKER_GPUS and
VLLM_RAY_BUNDLE_INDICES properly.
"""
import os

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM
from vllm.worker.worker import Worker


class MyWorker(Worker):

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        return current_platform.get_device_uuid(self.device.index)


class MyLLM(LLM):

    def __init__(self, *args, bundle_indices: list, **kwargs):
        # a hack to make the script work:
        # stop Ray from manipulating CUDA_VISIBLE_DEVICES
        # at the top level
        del os.environ["CUDA_VISIBLE_DEVICES"]
        # every worker will use 0.4 GPU, so that we can schedule
        # 2 instances on the same GPUs.
        os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(
            map(str, bundle_indices))
        print(f"creating LLM with bundle_indices={bundle_indices}")
        super().__init__(*args, **kwargs)


class RayTrainingActor:

    def report_device_id(self) -> str:
        # the argument for get_device_uuid is the index
        # of the GPU in the visible devices.
        # Ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
        from vllm.platforms import current_platform
        return current_platform.get_device_uuid(0)


# ray manages 4 GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.init()

# we want to co-locate the vLLM instances and the training actors
# on the same set of GPUs.
# the placement plan is as follows:
# GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2)
# GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2)

pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())
print(f"placement group has bundles {pg.bundle_specs=}")
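
# each bundle in the placement group reserves exactly one GPU, so a
# training actor and a vLLM worker scheduled on the same bundle index
# end up on the same physical GPU.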

training_actors = []
training_actor_device_ids = []
inference_engines = []
inference_engine_device_ids = []

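# each training actor requests 0.4 of a GPU; together with the 0.4
# requested by the co-located vLLM worker, this fits on the single GPU
# reserved by each bundle.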
for bundle_index in [0, 1, 2, 3]:
    training_actor = ray.remote(
        num_cpus=0,
        num_gpus=0.4,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_index,
        ),
    )(RayTrainingActor).remote()
    training_actors.append(training_actor)
    device_id = ray.get(training_actor.report_device_id.remote())
    print(f"training actor {bundle_index} is on {device_id}")
    training_actor_device_ids.append(device_id)

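# the MyLLM actor itself requests num_gpus=0: it only serves as the
# driver process, while the GPU-holding workers are created internally
# by vLLM's Ray executor using the VLLM_RAY_PER_WORKER_GPUS and
# VLLM_RAY_BUNDLE_INDICES values set in MyLLM.__init__ above.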
for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
    # IMPORTANT: when creating vLLM instances, make sure there is no
    # other GPU activity on the target GPUs, otherwise it will interfere
    # with vLLM's memory profiling and cause unexpected behavior.
    llm = ray.remote(
        num_cpus=0,
        num_gpus=0,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
        ),
    )(MyLLM).remote(
        model="facebook/opt-125m",
        enforce_eager=True,
        worker_cls=MyWorker,
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        gpu_memory_utilization=0.4,
        bundle_indices=bundle_indices,
    )
    inference_engines.append(llm)
    # don't call any method on the inference engine here,
    # otherwise it will block until the vLLM instance is created.

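# collective_rpc("report_device_id") invokes MyWorker.report_device_id
# on every tensor-parallel rank of the engine and returns one result per
# rank, so each entry below is a list of two device UUIDs.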
for i, llm in enumerate(inference_engines):
    inference_engine_device_ids.append(
        ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())))
    print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")

# check the placement
# the first two training actors should be
# on the same GPUs as the first inference engine
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# the last two training actors should be
# on the same GPUs as the second inference engine
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
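
# if both assertions pass, the co-location works as intended: each
# inference engine's two tensor-parallel workers share their GPUs with
# two of the training actors.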