
Commit 04b262c

Ying1123 and hnyls2002 authored
[Fix] Fix major performance bug in certain cases (sgl-project#1563)
Co-authored-by: hnyls2002 <[email protected]>
1 parent 2432ad4 · commit 04b262c

File tree

.github/workflows/pr-test.yml
python/sglang/bench_serving.py
python/sglang/srt/managers/scheduler.py
python/sglang/test/test_utils.py
test/srt/test_bench_serving.py

5 files changed: +50 −18 lines

.github/workflows/pr-test.yml

Lines changed: 6 additions & 0 deletions
@@ -130,6 +130,12 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default

+      - name: Benchmark Offline Throughput (Non-streaming, small batch size)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+
   performance-test-1-gpu-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner

python/sglang/bench_serving.py

Lines changed: 2 additions & 2 deletions
@@ -845,13 +845,15 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)

     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,

@@ -964,13 +966,11 @@ def set_ulimit(target_soft_limit=65535):
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
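With the defaults removed, --random-input-len and --random-output-len stay None unless the user passes them, and the new assertions enforce that they are supplied exactly when --dataset-name random is used. A minimal standalone sketch of the same validation pattern (the parser here is simplified; only the argument names and asserts are taken from the diff):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-name", choices=["sharegpt", "random"], default="sharegpt")
    # No defaults: a None value tells us the flag was not set by the user.
    parser.add_argument("--random-input-len", type=int)
    parser.add_argument("--random-output-len", type=int)
    args = parser.parse_args()

    if args.dataset_name == "sharegpt":
        # The random-length knobs do not apply to sharegpt; fail fast if set.
        assert args.random_input_len is None and args.random_output_len is None
    elif args.dataset_name == "random":
        # Random sampling is impossible without explicit lengths.
        assert args.random_input_len is not None and args.random_output_len is not None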

python/sglang/srt/managers/scheduler.py

Lines changed: 10 additions & 8 deletions
@@ -222,7 +222,7 @@ def __init__(
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.do_not_get_new_batch = False
+        self.batch_is_full = False

     def event_loop(self):
         while True:

@@ -261,12 +261,10 @@ def process_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(
                 recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
             ):
                 self.handle_embedding_request(recv_req)
-                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):

@@ -279,11 +277,12 @@ def process_requests(self, recv_reqs: List):

     @torch.inference_mode()
     def forward_step(self):
-        if self.do_not_get_new_batch and self.current_inflight_req is None:
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
             new_batch = None
         else:
             new_batch = self.get_new_prefill_batch()
-            self.do_not_get_new_batch = False

         if new_batch is not None:
             # Run a new prefill batch

@@ -447,6 +446,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
+            self.batch_is_full = True
             return None

         # Get priority queue

@@ -490,16 +490,19 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 )
                 > self.max_loras_per_batch
             ):
+                self.batch_is_full = True
                 break

             if adder.no_remaining_tokens():
+                self.batch_is_full = True
                 break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
+                self.batch_is_full = True
                 break

         can_run_list = adder.can_run_list

@@ -810,9 +813,6 @@ def forward_decode_batch(self, batch: ScheduleBatch):
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

-        if not has_finished:
-            self.do_not_get_new_batch = True
-
         self.handle_finished_requests(batch)

     def handle_finished_requests(self, batch: ScheduleBatch):

@@ -833,6 +833,8 @@ def handle_finished_requests(self, batch: ScheduleBatch):
         for i, req in enumerate(batch.reqs):
             if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
+            else:
+                self.batch_is_full = False

             if req.finished() or (
                 req.stream
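The performance bug is visible in the removed lines: do_not_get_new_batch was set after any decode step in which no request finished and cleared only when a new request arrived, so a finished request did not immediately re-enable prefill scheduling; with a fixed pool of waiting requests (the non-streaming, small-batch case the new CI test covers), the running batch could drain instead of being refilled. The replacement batch_is_full flag is raised exactly where admission fails (max running requests reached, LoRA limit exceeded, no remaining tokens, or add_one_req rejecting a request) and is cleared in handle_finished_requests as soon as any request finishes. A self-contained toy sketch of this flag lifecycle, using invented names (TinyScheduler, try_get_new_batch, finish) rather than the real SGLang classes:

    from collections import deque

    class TinyScheduler:
        """Toy model of the batch_is_full scheduling pattern in this commit."""

        def __init__(self, max_running_requests):
            self.max_running_requests = max_running_requests
            self.waiting_queue = deque()
            self.running = []
            self.batch_is_full = False

        def try_get_new_batch(self):
            # Mirrors forward_step: skip the (potentially expensive) batch-building
            # scan only when the last attempt hit a capacity limit or nothing waits.
            if self.batch_is_full or not self.waiting_queue:
                return None
            batch = []
            while self.waiting_queue:
                if len(self.running) + len(batch) >= self.max_running_requests:
                    # Admission failed: remember it so later steps skip this path.
                    self.batch_is_full = True
                    break
                batch.append(self.waiting_queue.popleft())
            self.running.extend(batch)
            return batch or None

        def finish(self, req):
            # A finished request frees capacity, so admission may succeed again.
            self.running.remove(req)
            self.batch_is_full = False

The flag now records a durable fact, "the last admission attempt hit a capacity limit", and is invalidated exactly when that fact can change, namely when a request finishes and frees capacity.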

python/sglang/test/test_utils.py

Lines changed: 14 additions & 5 deletions
@@ -514,7 +514,16 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)


-def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
     process = popen_launch_server(

@@ -530,21 +539,21 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
         base_url=base_url,
         host=None,
         port=None,
-        dataset_name="random",
+        dataset_name=dataset_name,
         dataset_path="",
         model=None,
         tokenizer=None,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
-        random_input_len=4096,
-        random_output_len=2048,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
         seed=0,
         output_file=None,
         disable_tqdm=False,
-        disable_stream=False,
+        disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
     )
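Because every new parameter defaults to the value that was previously hard-coded inside the helper, existing call sites are unchanged, while new callers can exercise the non-streaming ShareGPT path; the new test in test/srt/test_bench_serving.py below does exactly that. A hedged usage sketch (assuming the helper and the model-name constant are importable from sglang.test.test_utils, as in sglang's test suite):

    from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, run_bench_serving

    # Old-style call: behaves exactly as before this commit
    # (dataset_name="random", random_input_len=4096, random_output_len=2048).
    res = run_bench_serving(
        model=DEFAULT_MODEL_NAME_FOR_TEST,
        num_prompts=500,
        request_rate=float("inf"),
        other_server_args=[],
    )

    # New-style call: non-streaming ShareGPT benchmark with a small batch cap.
    res = run_bench_serving(
        model=DEFAULT_MODEL_NAME_FOR_TEST,
        num_prompts=200,
        request_rate=float("inf"),
        other_server_args=["--max-running-requests", "10"],
        dataset_name="sharegpt",
        random_input_len=None,   # must be None for sharegpt (new assert)
        random_output_len=None,
        disable_stream=True,
    )
    print(res["output_throughput"])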

test/srt/test_bench_serving.py

Lines changed: 18 additions & 3 deletions
@@ -20,7 +20,22 @@ def test_offline_throughput_default(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2830
+
+    def test_offline_throughput_non_stream_small_batch_size(self):
+        res = run_bench_serving(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            num_prompts=200,
+            request_rate=float("inf"),
+            dataset_name="sharegpt",
+            random_input_len=None,
+            random_output_len=None,
+            disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 1000

     def test_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(

@@ -31,7 +46,7 @@ def test_offline_throughput_without_radix_cache(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2800
+            assert res["output_throughput"] > 2880

     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(

@@ -58,7 +73,7 @@ def test_offline_throughput_with_triton_attention_backend(self):
         )

         if is_in_ci():
-            assert res["output_throughput"] > 2600
+            assert res["output_throughput"] > 2930

     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(
