
[Core] LoRA: V1 Scheduler optimization #15422


Merged

Conversation

Contributor

@varun-sundar-rabindranath commented Mar 25, 2025

LoRA Scheduler Optimization

Running Example:
Let max_loras be set to 4.
Let the waiting queue be: { R1, R2, R3, R4-L1, R5-L3, R6-L2, R7-L4, R8-L5, R9-L5, R10-L5, R11-L5, R12, R13, R14-L1, R15-L2, R16-L3 }
Rx - request number x; the request does not need any LoRA.
Rx-Ly - request number x, which needs LoRA number y.
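
For concreteness, the running example can be written out as data. This is a hypothetical sketch: `Request`, `lora_id`, and `MAX_LORAS` are illustrative names, not vLLM's actual types.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Request:
    name: str
    lora_id: Optional[int] = None  # None => the request needs no LoRA

# The waiting queue from the running example (Rx-Ly => Request("Rx", y)).
WAITING = [
    Request("R1"), Request("R2"), Request("R3"),
    Request("R4", 1), Request("R5", 3), Request("R6", 2), Request("R7", 4),
    Request("R8", 5), Request("R9", 5), Request("R10", 5), Request("R11", 5),
    Request("R12"), Request("R13"),
    Request("R14", 1), Request("R15", 2), Request("R16", 3),
]
MAX_LORAS = 4
```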

Why:
In V1 + LoRA, we currently stop scheduling waiting requests as soon as we can no longer honor the user-supplied max_loras limit. This is suboptimal, because requests further back in the queue are blocked from scheduling even when they could run, namely:

  • requests that do not use any LoRA, and
  • requests that use one of the already-scheduled LoRAs.

In the example above,

  • The scheduler schedules { R1, R2, R3, R4-L1, R5-L3, R6-L2, R7-L4 }
  • In future iterations the scheduler can't schedule additional requests until one of { R4-L1, R5-L3, R6-L2, R7-L4 } completes.
  • Let's say R4-L1 completes.
  • Now the scheduler will schedule { R1, R2, R3, R5-L3, R6-L2, R7-L4, R8-L5, R9-L5, R10-L5, R11-L5, R12, R13 },
  • and so on. (This stop-at-limit first pass is sketched in code below.)
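
Under the assumptions of the sketch above, the old stop-at-limit pass can be written as a loop that breaks the moment a request would need a fifth LoRA, blocking everything behind it:

```python
def schedule_stop_at_limit(waiting, max_loras):
    """Old behavior: stop scanning once max_loras would be exceeded."""
    scheduled, active_loras = [], set()
    for req in waiting:
        if req.lora_id is not None and req.lora_id not in active_loras:
            if len(active_loras) == max_loras:
                break  # R8-L5 blocks R12, R13, R14-L1, ... behind it
            active_loras.add(req.lora_id)
        scheduled.append(req)
    return scheduled

# [R1, R2, R3, R4, R5, R6, R7] -- matches the first bullet above.
print([r.name for r in schedule_stop_at_limit(WAITING, MAX_LORAS)])
```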

What:
This PR updates the scheduling logic to continue scanning the waiting queue for requests that can still be scheduled.
With this PR,

  • The scheduler will schedule { R1, R2, R3, R4-L1, R5-L3, R6-L2, R7-L4, R12, R13, R14-L1, R15-L2, R16-L3 }. The requests { R8-L5, R9-L5, R10-L5, R11-L5 } will stay in the waiting queue.
  • When one of the scheduled LoRA request sets completes, i.e. { R4-L1, R14-L1 }, { R5-L3, R16-L3 }, or { R6-L2, R15-L2 }, the scheduler will continue with { R8-L5, R9-L5, R10-L5, R11-L5 }. (See the sketch after this list.)
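
A minimal sketch of the new pass, reusing the hypothetical `Request`/`WAITING` definitions above: instead of breaking out of the loop, the scheduler sets the unschedulable request aside and keeps scanning.

```python
def schedule_skip_and_continue(waiting, max_loras):
    """New behavior: skip requests that would exceed max_loras,
    but keep scanning the rest of the waiting queue."""
    scheduled, skipped, active_loras = [], [], set()
    for req in waiting:
        needs_new_lora = (req.lora_id is not None
                          and req.lora_id not in active_loras)
        if needs_new_lora and len(active_loras) == max_loras:
            skipped.append(req)  # stays in the waiting queue
            continue
        if req.lora_id is not None:
            active_loras.add(req.lora_id)
        scheduled.append(req)
    return scheduled, skipped

scheduled, skipped = schedule_skip_and_continue(WAITING, MAX_LORAS)
# scheduled: R1..R7, R12, R13, R14-L1, R15-L2, R16-L3  (12 requests)
# skipped:   R8, R9, R10, R11                          (all LoRA 5)
```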

How:
We leverage the "skip waiting requests" logic that was introduced for structured decoding.
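
The mechanism itself is just queue bookkeeping. A simplified sketch of the pattern (illustrative names; the real scheduler tracks far more state, and `can_schedule` stands in for the combined LoRA and structured-output checks):

```python
from collections import deque

def drain_waiting_queue(waiting: deque, can_schedule) -> list:
    """Pop requests off the waiting queue; requests that cannot be
    scheduled yet are set aside rather than blocking the scan, then
    pushed back to the front afterwards in arrival order."""
    scheduled, skipped = [], deque()
    while waiting:
        req = waiting.popleft()
        if can_schedule(req):
            scheduled.append(req)
        else:
            skipped.append(req)
    waiting.extendleft(reversed(skipped))  # preserve arrival order
    return scheduled
```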

Benchmarks:

max_loras = 4, number of LoRA modules = 8, max_num_seqs = 256, max_num_batched_tokens = 4096

Server command:

vllm serve meta-llama/Llama-2-7b-hf \
  --enable-lora \
  --max-loras 4 \
  --max-lora-rank 8 \
  --lora-modules lora0=yard1/llama-2-7b-sql-lora-test lora1=yard1/llama-2-7b-sql-lora-test lora2=yard1/llama-2-7b-sql-lora-test lora3=yard1/llama-2-7b-sql-lora-test lora4=yard1/llama-2-7b-sql-lora-test lora5=yard1/llama-2-7b-sql-lora-test lora6=yard1/llama-2-7b-sql-lora-test lora7=yard1/llama-2-7b-sql-lora-test \
  --max-num-seqs 256 \
  --max-num-batched-tokens 4096 \
  --no-enable-prefix-caching \
  --port 9002 \
  --disable-log-stats

benchmark_serving.py command:

python3 benchmarks/benchmark_serving.py \
  --model meta-llama/Llama-2-7b-hf \
  --dataset-name sharegpt \
  --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
  --sharegpt-output-len 120 \
  --num-prompts 500 \
  --request-rate inf \
  --lora-modules lora0 lora1 lora2 lora3 lora4 lora5 lora6 lora7 \
  --ignore-eos \
  --port 9002 \
  --seed 2

main V1

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  200.71    
Total input tokens:                      124520    
Total generated tokens:                  60000     
Request throughput (req/s):              2.49      
Output token throughput (tok/s):         298.94    
Total Token throughput (tok/s):          919.35    
---------------Time to First Token----------------
Mean TTFT (ms):                          99070.16  
Median TTFT (ms):                        101450.02 
P99 TTFT (ms):                           195994.95 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          15.13     
Median TPOT (ms):                        15.24     
P99 TPOT (ms):                           16.17     
---------------Inter-token Latency----------------
Mean ITL (ms):                           15.31     
Median ITL (ms):                         15.22     
P99 ITL (ms):                            16.88     
==================================================

main V0

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  29.19     
Total input tokens:                      124520    
Total generated tokens:                  60000     
Request throughput (req/s):              17.13     
Output token throughput (tok/s):         2055.28   
Total Token throughput (tok/s):          6320.66   
---------------Time to First Token----------------
Mean TTFT (ms):                          10400.52  
Median TTFT (ms):                        6083.83   
P99 TTFT (ms):                           26071.17  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          79.85     
Median TPOT (ms):                        80.54     
P99 TPOT (ms):                           106.48    
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.85     
Median ITL (ms):                         61.96     
P99 ITL (ms):                            389.40    
==================================================

This PR

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  29.84     
Total input tokens:                      124520    
Total generated tokens:                  60000     
Request throughput (req/s):              16.75     
Output token throughput (tok/s):         2010.46   
Total Token throughput (tok/s):          6182.82   
---------------Time to First Token----------------
Mean TTFT (ms):                          11862.28  
Median TTFT (ms):                        6406.40   
P99 TTFT (ms):                           24552.50  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          73.55     
Median TPOT (ms):                        77.75     
P99 TPOT (ms):                           102.69    
---------------Inter-token Latency----------------
Mean ITL (ms):                           73.55     
Median ITL (ms):                         50.22     
P99 ITL (ms):                            346.09    
==================================================

Varun Sundar Rabindranath added 3 commits March 24, 2025 19:58
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 25, 2025
Collaborator

@comaniac left a comment


Otherwise LGTM

@varun-sundar-rabindranath
Contributor Author

Requesting review from @russellb for the changes in "Structured Outputs" land! 🙌

@varun-sundar-rabindranath
Contributor Author

Requesting review from @jeejeelee 🙌

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/lora-scheduler-optimization branch from 1ca090c to 7675df8 Compare March 25, 2025 01:07
Collaborator

@WoosukKwon left a comment


Besides the variable name, the PR looks good to me!
@varun-sundar-rabindranath Do you have any performance numbers after this PR?

@varun-sundar-rabindranath
Contributor Author

> Besides the variable name, the PR looks good to me! @varun-sundar-rabindranath Do you have any performance numbers after this PR?

I am running some benchmarks now. I'll add them to the PR 👍

@varun-sundar-rabindranath
Contributor Author

> Besides the variable name, the PR looks good to me! @varun-sundar-rabindranath Do you have any performance numbers after this PR?

Hi @WoosukKwon - I added benchmark numbers to the PR description. The change definitely helps V1 when max_loras < the number of LoRAs used. However, V1 LoRA in this case still lags behind V0 - I think it has to do with the set_active_loras overhead in scheduling. I can handle that in a separate PR.

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2025
@WoosukKwon WoosukKwon enabled auto-merge (squash) March 25, 2025 19:44
@WoosukKwon WoosukKwon merged commit a5cfbab into vllm-project:main Mar 25, 2025
46 checks passed
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Wes Medford <[email protected]>
lengrongfu pushed a commit to lengrongfu/vllm that referenced this pull request Apr 2, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1