Skip to content

Commit 5d0ba40

Browse files
authored
Refine the add request reasons to avoid corner cases. (sgl-project#1574)
1 parent 04b262c commit 5d0ba40

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

python/sglang/srt/managers/schedule_policy.py

Lines changed: 24 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import random
2020
from collections import defaultdict
2121
from contextlib import contextmanager
22+
from enum import Enum, auto
2223
from typing import Dict, List, Optional
2324

2425
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -104,6 +105,12 @@ def get_dfs_priority(
104105
q.extend(last_node_to_reqs[cur_node])
105106

106107

108+
class AddReqResult(Enum):
    """Outcome reported after trying to add a request to a prefill batch."""

    # Continue to add requests
    CONTINUE = auto()
    # No token left
    NO_TOKEN = auto()
    # Other reasons to stop adding requests
    OTHER = auto()
112+
113+
107114
class PrefillAdder:
108115
def __init__(
109116
self,
@@ -145,17 +152,16 @@ def __init__(
145152
]
146153
)
147154

148-
def no_remaining_tokens(self):
    """Return True when any of the tracked token budgets is exhausted.

    A budget counts as exhausted when its remaining value is <= 0. The
    chunk budget (``rem_chunk_tokens``) is only considered when it is not
    ``None``.
    """
    return (
        self.rem_total_tokens <= 0
        or self.rem_input_tokens <= 0
        # `is not None and ...` replaces the `... if ... else False` form.
        or (self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0)
        or self.cur_rem_tokens <= 0
    )
155+
def budget_state(self):
    """Classify the remaining budgets as an ``AddReqResult``.

    Returns:
        ``AddReqResult.NO_TOKEN`` if ``rem_total_tokens`` or
        ``cur_rem_tokens`` is <= 0; otherwise ``AddReqResult.OTHER`` if
        ``rem_input_tokens`` is <= 0 or ``rem_chunk_tokens`` is set and
        <= 0; otherwise ``AddReqResult.CONTINUE``.
    """
    # The token checks come first, so NO_TOKEN wins over OTHER when both
    # kinds of budget are exhausted.
    if self.rem_total_tokens <= 0:
        return AddReqResult.NO_TOKEN
    if self.cur_rem_tokens <= 0:
        return AddReqResult.NO_TOKEN

    if self.rem_input_tokens <= 0:
        return AddReqResult.OTHER

    # A chunk budget of None is never treated as exhausted.
    chunk_budget = self.rem_chunk_tokens
    if chunk_budget is not None and chunk_budget <= 0:
        return AddReqResult.OTHER

    return AddReqResult.CONTINUE
159165

160166
def _prefill_one_req(
161167
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
@@ -239,7 +245,7 @@ def add_req_state(r, insert_sort=False):
239245
)
240246
bs = len(self.req_states) - i
241247
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
242-
return False
248+
return AddReqResult.NO_TOKEN
243249
tokens_freed += tokens_occupied
244250

245251
if req.extend_input_len <= self.rem_chunk_tokens:
@@ -258,7 +264,7 @@ def add_req_state(r, insert_sort=False):
258264
self.new_inflight_req = req
259265
self._prefill_one_req(0, trunc_len, 0)
260266

261-
return True
267+
return self.budget_state()
262268

263269
def add_one_req(self, req: Req):
264270
if req.sampling_params.ignore_eos and self.tree_cache.disable:
@@ -271,14 +277,14 @@ def add_one_req(self, req: Req):
271277
prefix_len = len(req.prefix_indices)
272278

273279
if total_tokens >= self.rem_total_tokens:
274-
return False
280+
return AddReqResult.NO_TOKEN
275281

276282
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
277-
return False
283+
return AddReqResult.OTHER
278284

279285
with self._lock_node(req.last_node):
280286
if total_tokens > self.rem_total_tokens:
281-
return False
287+
return AddReqResult.NO_TOKEN
282288

283289
if (
284290
self.rem_chunk_tokens is None
@@ -297,7 +303,7 @@ def add_one_req(self, req: Req):
297303
# Chunked prefill
298304
trunc_len = self.rem_chunk_tokens
299305
if trunc_len == 0:
300-
return False
306+
return AddReqResult.OTHER
301307

302308
req.extend_input_len = trunc_len
303309
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
@@ -306,4 +312,4 @@ def add_one_req(self, req: Req):
306312
self.tree_cache.inc_lock_ref(req.last_node)
307313
self._prefill_one_req(prefix_len, trunc_len, 0)
308314

309-
return True and not self.no_remaining_tokens()
315+
return self.budget_state()

python/sglang/srt/managers/scheduler.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,11 @@
5050
Req,
5151
ScheduleBatch,
5252
)
53-
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
53+
from sglang.srt.managers.schedule_policy import (
54+
AddReqResult,
55+
PrefillAdder,
56+
SchedulePolicy,
57+
)
5458
from sglang.srt.managers.tp_worker import TpModelWorker
5559
from sglang.srt.mem_cache.chunk_cache import ChunkCache
5660
from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -493,16 +497,15 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
493497
self.batch_is_full = True
494498
break
495499

496-
if adder.no_remaining_tokens():
500+
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
497501
self.batch_is_full = True
498502
break
503+
499504
req.init_next_round_input(None if prefix_computed else self.tree_cache)
500505
res = adder.add_one_req(req)
501-
if (
502-
not res
503-
or running_bs + len(adder.can_run_list) >= self.max_running_requests
504-
):
505-
self.batch_is_full = True
506+
if res != AddReqResult.CONTINUE:
507+
if res == AddReqResult.NO_TOKEN:
508+
self.batch_is_full = True
506509
break
507510

508511
can_run_list = adder.can_run_list

0 commit comments

Comments
 (0)