
Commit ac2dc35

support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (sgl-project#3030)
1 parent 3e032c0 commit ac2dc35
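
The new sgl-kernel op uses an out-parameter calling convention: the caller pre-allocates the attention output and the updated KV-state tensor, and the kernel writes into them, as the updated benchmark below does. Here is a minimal call-pattern sketch; the example sizes, shapes, and dtypes are assumptions inferred from the benchmark code rather than anything this commit documents.

import torch
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode

# Arbitrary example sizes; shapes are inferred from the benchmark (assumptions):
#   q, k, v:    [batch, num_heads, 1, head_dim]  (a single decode token)
#   past_kv:    [batch, num_heads, head_dim, head_dim]
#   slope_rate: [num_heads, 1, 1]
batch, num_heads, head_dim = 2, 8, 64
q = torch.randn(batch, num_heads, 1, head_dim, device="cuda").contiguous()
k = torch.randn_like(q)
v = torch.randn_like(q)
past_kv = torch.randn(batch, num_heads, head_dim, head_dim, device="cuda")
slope_rate = torch.rand(num_heads, 1, 1, device="cuda")

# Pre-allocate the outputs; the kernel fills them in place instead of returning tensors.
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)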

File tree

8 files changed: +588 -8 lines changed

benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py (+69 -8)
@@ -9,6 +9,7 @@
 import triton
 import triton.language as tl
 from einops import rearrange
+from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode


 @triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
         model_params["num_attention_heads"],
         d,
         d,
-        dtype=dtype,
         device=device,
     )
     with torch.no_grad():
@@ -350,30 +350,64 @@ def test_lightning_attention_implementations(model_params):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
     v = v.transpose(1, 2)
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    past_kv = past_kv.contiguous()
+    slope_rate = slope_rate.contiguous()

+    # Test Triton implementation
     triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
     triton_output = triton_output.transpose(1, 2).contiguous()
     triton_output = triton_output.view(batch_size, seq_len, -1)
     triton_output = model_attn.norm(triton_output)
     triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
     triton_output = model_attn.out_proj(triton_output)

+    # Test SGL implementation
+    sgl_output = torch.empty_like(v)
+    sgl_new_kv = torch.empty_like(past_kv)
+    sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
+
+    sgl_output = sgl_output.transpose(1, 2).contiguous()
+    sgl_output = sgl_output.view(batch_size, seq_len, -1)
+    sgl_output = model_attn.norm(sgl_output)
+    sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
+    sgl_output = model_attn.out_proj(sgl_output)
+
+    # Verify Triton implementation results
     torch.testing.assert_close(
         model_output,
         triton_output,
         rtol=1e-3,
         atol=1e-2,
-        msg="Lightning attention implementations produce different output results",
+        msg="Triton lightning attention implementation produces different output results",
     )
     torch.testing.assert_close(
         new_kv,
         triton_new_kv,
         rtol=1e-3,
         atol=1e-2,
-        msg="Lightning attention implementations produce different kv results",
+        msg="Triton lightning attention implementation produces different kv results",
     )

-    print("✅ Two implementations match")
+    # Verify SGL implementation results
+    torch.testing.assert_close(
+        model_output,
+        sgl_output,
+        rtol=1e-3,
+        atol=1e-2,
+        msg="SGL lightning attention implementation produces different output results",
+    )
+    torch.testing.assert_close(
+        new_kv,
+        sgl_new_kv,
+        rtol=1e-3,
+        atol=1e-2,
+        msg="SGL lightning attention implementation produces different kv results",
+    )
+
+    print("✅ All implementations match")


 def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
             x_names=["batch_size", "seq_len"],
             x_vals=[list(_) for _ in configs],
             line_arg="provider",
-            line_vals=["Original", "Triton"],
+            line_vals=["Original", "Triton", "SGL"],
             line_names=[
                 "Original PyTorch Implementation",
                 "Triton Implementation",
+                "SGL Implementation",
             ],
-            styles=[("blue", "-"), ("green", "-")],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
             ylabel="us",
             plot_name="lightning-attention-decode-performance",
             args={},
@@ -446,7 +481,6 @@ def benchmark(batch_size, seq_len, provider):
         params["num_attention_heads"],
         d,
         d,
-        dtype=dtype,
         device=device,
     )

@@ -461,7 +495,7 @@ def benchmark(batch_size, seq_len, provider):
             ),
             quantiles=quantiles,
         )
-    else:
+    elif provider == "Triton":

         def run_triton():
             qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def run_triton():
             run_triton,
             quantiles=quantiles,
         )
+    else:  # SGL
+
+        def run_sgl():
+            qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
+            new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
+            qkv = qkv.view(*new_shape)
+            q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
+            q = q.transpose(1, 2).contiguous()
+            k = k.transpose(1, 2).contiguous()
+            v = v.transpose(1, 2).contiguous()
+
+            output = torch.empty_like(v)
+            new_kv = torch.empty_like(past_kv)
+            sgl_lightning_attention_decode(
+                q, k, v, past_kv, slope_rate, output, new_kv
+            )
+
+            output = output.transpose(1, 2).contiguous()
+            output = output.view(batch_size, seq_len, -1)
+            output = model_attn.norm(output)
+            output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
+            return model_attn.out_proj(output)
+
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            run_sgl,
+            quantiles=quantiles,
+        )

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms

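For reference, the decode step of lightning attention that the kernel accelerates is a linear-attention recurrence: the per-head KV state is decayed by exp(-slope) and updated with the outer product of the new key and value, and the output is the query applied to that state. The plain-PyTorch sketch below mirrors that formula under the same assumed shapes as the example above; it is a readability aid, not the kernel's actual implementation.

import torch

def lightning_attention_decode_ref(q, k, v, past_kv, slope_rate):
    # Assumed shapes: q, k, v [b, h, 1, d]; past_kv [b, h, d, d]; slope_rate [h, 1, 1].
    ratio = torch.exp(-slope_rate)  # per-head decay factor
    # Decay the running state and add the rank-1 update k^T v for the new token.
    new_kv = ratio * past_kv + torch.einsum("bhnd,bhne->bhde", k, v)
    # Read the state with the query: o = q @ kv.
    output = torch.einsum("bhnd,bhde->bhne", q, new_kv)
    return output, new_kv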