@@ -427,10 +427,7 @@ def forward_extend(
427
427
else :
428
428
o2 , s2 = prefill_wrapper_paged .forward_return_lse (
429
429
q .contiguous ().view (- 1 , layer .tp_q_head_num , layer .head_dim ),
430
- self ._to_dtype (
431
- forward_batch .token_to_kv_pool .get_kv_buffer (layer .layer_id ),
432
- q .dtype ,
433
- ),
430
+ forward_batch .token_to_kv_pool .get_kv_buffer (layer .layer_id ),
434
431
causal = False ,
435
432
sm_scale = layer .scaling ,
436
433
logits_soft_cap = layer .logit_cap ,
@@ -472,9 +469,7 @@ def forward_decode(
472
469
473
470
o = decode_wrapper .forward (
474
471
q .contiguous ().view (- 1 , layer .tp_q_head_num , layer .head_dim ),
475
- self ._to_dtype (
476
- forward_batch .token_to_kv_pool .get_kv_buffer (layer .layer_id ), q .dtype
477
- ),
472
+ forward_batch .token_to_kv_pool .get_kv_buffer (layer .layer_id ),
478
473
sm_scale = layer .scaling ,
479
474
logits_soft_cap = layer .logit_cap ,
480
475
k_scale = layer .k_scale ,
@@ -483,12 +478,6 @@ def forward_decode(
483
478
484
479
return o .view (- 1 , layer .tp_q_head_num * layer .head_dim )
485
480
486
- def _to_dtype (self , kv_tuple , dtype ):
487
- if kv_tuple [0 ].dtype != dtype :
488
- return tuple (t .to (dtype ) for t in kv_tuple )
489
- else :
490
- return kv_tuple
491
-
492
481
def _get_wrapper_idx (self , layer : RadixAttention ):
493
482
if self .num_wrappers == 1 :
494
483
return 0
0 commit comments