Skip to content

Commit 2dd7d0c

Browse files
authored
Revert "Fix nightly-test CI" (sgl-project#4065)
1 parent 0d4e322 commit 2dd7d0c

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

python/sglang/srt/layers/attention/flashinfer_backend.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,7 @@ def forward_extend(
427427
else:
428428
o2, s2 = prefill_wrapper_paged.forward_return_lse(
429429
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
430-
self._to_dtype(
431-
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
432-
q.dtype,
433-
),
430+
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
434431
causal=False,
435432
sm_scale=layer.scaling,
436433
logits_soft_cap=layer.logit_cap,
@@ -472,9 +469,7 @@ def forward_decode(
472469

473470
o = decode_wrapper.forward(
474471
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
475-
self._to_dtype(
476-
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), q.dtype
477-
),
472+
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
478473
sm_scale=layer.scaling,
479474
logits_soft_cap=layer.logit_cap,
480475
k_scale=layer.k_scale,
@@ -483,12 +478,6 @@ def forward_decode(
483478

484479
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
485480

486-
def _to_dtype(self, kv_tuple, dtype):
487-
if kv_tuple[0].dtype != dtype:
488-
return tuple(t.to(dtype) for t in kv_tuple)
489-
else:
490-
return kv_tuple
491-
492481
def _get_wrapper_idx(self, layer: RadixAttention):
493482
if self.num_wrappers == 1:
494483
return 0

0 commit comments

Comments
 (0)