 import triton
 import triton.language as tl
 from einops import rearrange
+from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode


 @triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
         model_params["num_attention_heads"],
         d,
         d,
-        dtype=dtype,
         device=device,
     )
     with torch.no_grad():
@@ -350,30 +350,64 @@ def test_lightning_attention_implementations(model_params):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)
     v = v.transpose(1, 2)
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    past_kv = past_kv.contiguous()
+    slope_rate = slope_rate.contiguous()

+    # Test Triton implementation
     triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
     triton_output = triton_output.transpose(1, 2).contiguous()
     triton_output = triton_output.view(batch_size, seq_len, -1)
     triton_output = model_attn.norm(triton_output)
     triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
     triton_output = model_attn.out_proj(triton_output)

+    # Test SGL implementation
+    sgl_output = torch.empty_like(v)
+    sgl_new_kv = torch.empty_like(past_kv)
+    sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
+
+    sgl_output = sgl_output.transpose(1, 2).contiguous()
+    sgl_output = sgl_output.view(batch_size, seq_len, -1)
+    sgl_output = model_attn.norm(sgl_output)
+    sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
+    sgl_output = model_attn.out_proj(sgl_output)
+
+    # Verify Triton implementation results
     torch.testing.assert_close(
         model_output,
         triton_output,
         rtol=1e-3,
         atol=1e-2,
-        msg="Lightning attention implementations produce different output results",
+        msg="Triton lightning attention implementation produces different output results",
     )
     torch.testing.assert_close(
         new_kv,
         triton_new_kv,
         rtol=1e-3,
         atol=1e-2,
-        msg="Lightning attention implementations produce different kv results",
+        msg="Triton lightning attention implementation produces different kv results",
     )

-    print("✅ Two implementations match")
+    # Verify SGL implementation results
+    torch.testing.assert_close(
+        model_output,
+        sgl_output,
+        rtol=1e-3,
+        atol=1e-2,
+        msg="SGL lightning attention implementation produces different output results",
+    )
+    torch.testing.assert_close(
+        new_kv,
+        sgl_new_kv,
+        rtol=1e-3,
+        atol=1e-2,
+        msg="SGL lightning attention implementation produces different kv results",
+    )
+
+    print("✅ All implementations match")


 def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
             x_names=["batch_size", "seq_len"],
             x_vals=[list(_) for _ in configs],
             line_arg="provider",
-            line_vals=["Original", "Triton"],
+            line_vals=["Original", "Triton", "SGL"],
             line_names=[
                 "Original PyTorch Implementation",
                 "Triton Implementation",
+                "SGL Implementation",
             ],
-            styles=[("blue", "-"), ("green", "-")],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
             ylabel="us",
             plot_name="lightning-attention-decode-performance",
             args={},
@@ -446,7 +481,6 @@ def benchmark(batch_size, seq_len, provider):
             params["num_attention_heads"],
             d,
             d,
-            dtype=dtype,
             device=device,
         )

@@ -461,7 +495,7 @@ def benchmark(batch_size, seq_len, provider):
                 ),
                 quantiles=quantiles,
             )
-        else:
+        elif provider == "Triton":

             def run_triton():
                 qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def run_triton():
                 run_triton,
                 quantiles=quantiles,
             )
+        else:  # SGL
+
+            def run_sgl():
+                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
+                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
+                qkv = qkv.view(*new_shape)
+                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
+                q = q.transpose(1, 2).contiguous()
+                k = k.transpose(1, 2).contiguous()
+                v = v.transpose(1, 2).contiguous()
+
+                output = torch.empty_like(v)
+                new_kv = torch.empty_like(past_kv)
+                sgl_lightning_attention_decode(
+                    q, k, v, past_kv, slope_rate, output, new_kv
+                )
+
+                output = output.transpose(1, 2).contiguous()
+                output = output.view(batch_size, seq_len, -1)
+                output = model_attn.norm(output)
+                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
+                return model_attn.out_proj(output)
+
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                run_sgl,
+                quantiles=quantiles,
+            )

         return 1000 * ms, 1000 * max_ms, 1000 * min_ms
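For reference, below is a minimal standalone sketch (not part of the diff) of how the new sgl_kernel entry point is driven, mirroring the SGL path of the test above. Only the call signature and argument order are taken from the diff; the tensor shapes, dtypes, and sizes used here are assumptions inferred from the surrounding test code, not documented requirements of sgl_kernel.

import torch
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode

# Assumed sizes for illustration only; the benchmark sweeps batch_size and seq_len.
batch_size, num_heads, head_dim = 2, 64, 96
seq_len = 1  # decode processes one token per step
device = "cuda"

# q/k/v laid out as (batch, heads, seq_len, head_dim) and contiguous, as in the test.
# bfloat16 is an assumption; the test takes its dtype from the model under test.
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Running KV state of shape (batch, heads, head_dim, head_dim), left in default
# float32, matching the removal of dtype=dtype when past_kv is allocated above.
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)

# Per-head decay slopes of shape (heads, 1, 1), as produced by _build_slope_tensor.
slope_rate = torch.rand(num_heads, 1, 1, device=device)

# The kernel writes into preallocated output buffers instead of returning tensors.
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)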