19
19
import random
20
20
from collections import defaultdict
21
21
from contextlib import contextmanager
22
+ from enum import Enum , auto
22
23
from typing import Dict , List , Optional
23
24
24
25
from sglang .srt .managers .schedule_batch import Req , ScheduleBatch
@@ -104,6 +105,12 @@ def get_dfs_priority(
104
105
q .extend (last_node_to_reqs [cur_node ])
105
106
106
107
108
class AddReqResult(Enum):
    """Outcome returned by the request-adding methods of ``PrefillAdder``.

    The member tells the caller whether more requests may be added and,
    if not, which kind of limit stopped the process.
    """

    CONTINUE = auto()  # Continue to add requests
    NO_TOKEN = auto()  # No token left
    OTHER = auto()  # Other reasons to stop adding requests
112
+
113
+
107
114
class PrefillAdder :
108
115
def __init__ (
109
116
self ,
@@ -145,17 +152,16 @@ def __init__(
145
152
]
146
153
)
147
154
148
def no_remaining_tokens(self) -> bool:
    """Return True when any tracked token budget has run out.

    A budget counts as exhausted once its value is ``<= 0``. The budgets
    checked are ``rem_total_tokens``, ``rem_input_tokens``,
    ``rem_chunk_tokens`` and ``cur_rem_tokens``.

    Returns:
        ``True`` if at least one budget is exhausted, ``False`` otherwise.
    """
    return (
        self.rem_total_tokens <= 0
        or self.rem_input_tokens <= 0
        # rem_chunk_tokens may be None; in that case it is ignored.
        or (self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0)
        or self.cur_rem_tokens <= 0
    )
155
def budget_state(self) -> "AddReqResult":
    """Classify the remaining token budgets as an ``AddReqResult``.

    Returns:
        ``AddReqResult.NO_TOKEN`` if ``rem_total_tokens`` or
        ``cur_rem_tokens`` is ``<= 0``; otherwise ``AddReqResult.OTHER``
        if ``rem_input_tokens`` is ``<= 0`` or ``rem_chunk_tokens`` is
        not ``None`` and ``<= 0``; otherwise ``AddReqResult.CONTINUE``.
    """
    # Checked first, so running out of total/current tokens takes
    # precedence over the input and chunk limits below.
    if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
        return AddReqResult.NO_TOKEN

    # rem_chunk_tokens may be None; in that case it is ignored.
    if self.rem_input_tokens <= 0 or (
        self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
    ):
        return AddReqResult.OTHER

    return AddReqResult.CONTINUE
159
165
160
166
def _prefill_one_req (
161
167
self , prefix_len : int , extend_input_len : int , max_new_tokens : int
@@ -239,7 +245,7 @@ def add_req_state(r, insert_sort=False):
239
245
)
240
246
bs = len (self .req_states ) - i
241
247
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0 :
242
- return False
248
+ return AddReqResult . NO_TOKEN
243
249
tokens_freed += tokens_occupied
244
250
245
251
if req .extend_input_len <= self .rem_chunk_tokens :
@@ -258,7 +264,7 @@ def add_req_state(r, insert_sort=False):
258
264
self .new_inflight_req = req
259
265
self ._prefill_one_req (0 , trunc_len , 0 )
260
266
261
- return True
267
+ return self . budget_state ()
262
268
263
269
def add_one_req (self , req : Req ):
264
270
if req .sampling_params .ignore_eos and self .tree_cache .disable :
@@ -271,14 +277,14 @@ def add_one_req(self, req: Req):
271
277
prefix_len = len (req .prefix_indices )
272
278
273
279
if total_tokens >= self .rem_total_tokens :
274
- return False
280
+ return AddReqResult . NO_TOKEN
275
281
276
282
if input_tokens > self .rem_input_tokens and len (self .can_run_list ) != 0 :
277
- return False
283
+ return AddReqResult . OTHER
278
284
279
285
with self ._lock_node (req .last_node ):
280
286
if total_tokens > self .rem_total_tokens :
281
- return False
287
+ return AddReqResult . NO_TOKEN
282
288
283
289
if (
284
290
self .rem_chunk_tokens is None
@@ -297,7 +303,7 @@ def add_one_req(self, req: Req):
297
303
# Chunked prefill
298
304
trunc_len = self .rem_chunk_tokens
299
305
if trunc_len == 0 :
300
- return False
306
+ return AddReqResult . OTHER
301
307
302
308
req .extend_input_len = trunc_len
303
309
req .fill_ids = req .fill_ids [: len (req .prefix_indices ) + trunc_len ]
@@ -306,4 +312,4 @@ def add_one_req(self, req: Req):
306
312
self .tree_cache .inc_lock_ref (req .last_node )
307
313
self ._prefill_one_req (prefix_len , trunc_len , 0 )
308
314
309
- return True and not self .no_remaining_tokens ()
315
+ return self .budget_state ()
0 commit comments