@@ -108,11 +108,9 @@ def mha_forward_kernel(
108
108
# q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim.
109
109
curr_q_slice = pl .dslice (start_q * block_q , block_q )
110
110
head_mask = (jnp .arange (head_dim_padded ) < head_dim )[None , :]
111
- q = pl .load (q_ref , ( slice ( None ), slice ( None )) , mask = head_mask , other = 0.0 )
111
+ q = plgpu .load (q_ref , mask = head_mask , other = 0.0 )
112
112
q_segment_ids = (
113
- None
114
- if segment_ids_ref is None
115
- else pl .load (segment_ids_ref , (curr_q_slice ,))
113
+ None if segment_ids_ref is None else segment_ids_ref [curr_q_slice ]
116
114
)
117
115
# In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
118
116
# (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
@@ -122,7 +120,7 @@ def body(start_k, carry):
122
120
o_prev , m_prev , l_prev = carry
123
121
curr_k_slice = pl .dslice (start_k * block_k , block_k )
124
122
125
- k = pl .load (k_ref , ( curr_k_slice , slice ( None )) , mask = head_mask , other = 0.0 )
123
+ k = plgpu .load (k_ref . at [ curr_k_slice , :] , mask = head_mask , other = 0.0 )
126
124
qk = pl .dot (q , k .T ) # [block_q, block_k]
127
125
128
126
# Scale logits to convert from base-2 to the natural log domain.
@@ -140,7 +138,7 @@ def body(start_k, carry):
140
138
if causal or segment_ids_ref is not None :
141
139
mask = None
142
140
if segment_ids_ref is not None :
143
- kv_segment_ids = pl . load ( segment_ids_ref , ( curr_k_slice ,))
141
+ kv_segment_ids = segment_ids_ref [ curr_k_slice ]
144
142
mask = segment_mask (q_segment_ids , kv_segment_ids )
145
143
if causal :
146
144
span_q = start_q * block_q + jnp .arange (block_q )
@@ -162,7 +160,7 @@ def body(start_k, carry):
162
160
l_curr = s_curr .sum (axis = - 1 )
163
161
l_next = l_prev_corr + l_curr
164
162
o_prev_corr = correction [:, None ] * o_prev
165
- v = pl .load (v_ref , ( curr_k_slice , slice ( None )) , mask = head_mask )
163
+ v = plgpu .load (v_ref . at [ curr_k_slice , :] , mask = head_mask )
166
164
o_curr = pl .dot (s_curr .astype (v .dtype ), v )
167
165
168
166
o_next = o_prev_corr + o_curr
@@ -183,8 +181,7 @@ def body(start_k, carry):
183
181
lse_ref = residual_refs [0 ]
184
182
lse_ref [...] = m_i + jnp .log2 (l_i )
185
183
# Write output to dram.
186
- pl .store (o_ref , (slice (None ), slice (o .shape [- 1 ])), o .astype (o_ref .dtype ),
187
- mask = head_mask )
184
+ plgpu .store (o_ref .at [:, : o .shape [- 1 ]], o .astype (o_ref .dtype ), mask = head_mask )
188
185
189
186
def segment_mask (
190
187
q_segment_ids : jax .Array ,
@@ -328,8 +325,8 @@ def _mha_forward(
328
325
def _preprocess_backward_kernel (out_ref , dout_ref , delta_ref , head_dim : int ):
329
326
# load
330
327
head_mask = (jnp .arange (out_ref .shape [- 1 ]) < head_dim )[None , :]
331
- o = pl .load (out_ref , ( slice ( None ), slice ( None )) , mask = head_mask , other = 0.0 )
332
- do = pl .load (dout_ref , ( slice ( None ), slice ( None )) , mask = head_mask , other = 0.0 )
328
+ o = plgpu .load (out_ref , mask = head_mask , other = 0.0 )
329
+ do = plgpu .load (dout_ref , mask = head_mask , other = 0.0 )
333
330
# compute
334
331
delta = jnp .sum (o * do , axis = 1 )
335
332
# write-back
@@ -402,20 +399,18 @@ def mha_backward_kernel(
402
399
dk = jnp .zeros ([block_kv_dkv , head_dim_padded ], dtype = jnp .float32 )
403
400
404
401
head_mask = (jnp .arange (head_dim_padded ) < head_dim )[None , :]
405
- v = pl .load (v_ref , ( curr_k_slice , slice ( None )) , mask = head_mask , other = 0.0 )
406
- k = pl .load (k_ref , ( curr_k_slice , slice ( None )) , mask = head_mask , other = 0.0 )
402
+ v = plgpu .load (v_ref . at [ curr_k_slice , :] , mask = head_mask , other = 0.0 )
403
+ k = plgpu .load (k_ref . at [ curr_k_slice , :] , mask = head_mask , other = 0.0 )
407
404
span_k = start_k * block_kv_dkv + jnp .arange (block_kv_dkv )
408
405
kv_segment_ids = (
409
- None
410
- if segment_ids_ref is None
411
- else pl .load (segment_ids_ref , (curr_k_slice ,))
406
+ None if segment_ids_ref is None else segment_ids_ref [curr_k_slice ]
412
407
)
413
408
414
409
def inner_loop_dkdv (start_q , carry ):
415
410
dv , dk = carry
416
411
curr_q_slice = pl .dslice (start_q * block_q_dkv , block_q_dkv )
417
412
418
- q = pl .load (q_ref , ( curr_q_slice , slice ( None )) , mask = head_mask , other = 0.0 )
413
+ q = plgpu .load (q_ref . at [ curr_q_slice , :] , mask = head_mask , other = 0.0 )
419
414
qk = pl .dot (q , k .T )
420
415
qk_scale = math .log2 (math .e )
421
416
if sm_scale != 1. :
@@ -425,7 +420,7 @@ def inner_loop_dkdv(start_q, carry):
425
420
if causal or segment_ids_ref is not None :
426
421
mask = None
427
422
if segment_ids_ref is not None :
428
- q_segment_ids = pl . load ( segment_ids_ref , ( curr_q_slice ,))
423
+ q_segment_ids = segment_ids_ref [ curr_q_slice ]
429
424
mask = segment_mask (q_segment_ids , kv_segment_ids )
430
425
431
426
if causal :
@@ -436,10 +431,11 @@ def inner_loop_dkdv(start_q, carry):
436
431
)
437
432
qk = jnp .where (mask , qk , DEFAULT_MASK_VALUE )
438
433
439
- lse = pl .load (lse_ref , (curr_q_slice ,))
440
- di = pl .load (delta_ref , (curr_q_slice ,))
441
- do = pl .load (do_scaled_ref , (curr_q_slice , slice (None )), mask = head_mask ,
442
- other = 0.0 )
434
+ lse = lse_ref [curr_q_slice ]
435
+ di = delta_ref [curr_q_slice ]
436
+ do = plgpu .load (
437
+ do_scaled_ref .at [curr_q_slice , :], mask = head_mask , other = 0.0
438
+ )
443
439
444
440
p = jnp .exp2 (qk - lse [:, None ])
445
441
dv = dv + pl .dot (p .astype (do .dtype ).T , do )
@@ -456,10 +452,12 @@ def inner_loop_dkdv(start_q, carry):
456
452
dv , dk = lax .fori_loop (
457
453
lower_bound , pl .cdiv (q_seq_len , block_q_dkv ), inner_loop_dkdv , (dv , dk )
458
454
)
459
- pl .store (dv_ref , (slice (None ), slice (dv .shape [- 1 ])), dv .astype (dv_ref .dtype ),
460
- mask = head_mask )
461
- pl .store (dk_ref , (slice (None ), slice (dk .shape [- 1 ])), dk .astype (dk_ref .dtype ),
462
- mask = head_mask )
455
+ plgpu .store (
456
+ dv_ref .at [:, : dv .shape [- 1 ]], dv .astype (dv_ref .dtype ), mask = head_mask
457
+ )
458
+ plgpu .store (
459
+ dk_ref .at [:, : dk .shape [- 1 ]], dk .astype (dk_ref .dtype ), mask = head_mask
460
+ )
463
461
464
462
# Scan #2: dQ
465
463
# 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM.
@@ -470,21 +468,18 @@ def inner_loop_dkdv(start_q, carry):
470
468
span_q = start_q * block_q_dq + jnp .arange (block_q_dq )
471
469
dq = jnp .zeros ([block_q_dq , head_dim_padded ], dtype = jnp .float32 )
472
470
473
- q = pl .load (q_ref , ( curr_q_slice , slice ( None )) , mask = head_mask , other = 0.0 )
471
+ q = plgpu .load (q_ref . at [ curr_q_slice , :] , mask = head_mask , other = 0.0 )
474
472
q_segment_ids = (
475
- None
476
- if segment_ids_ref is None
477
- else pl .load (segment_ids_ref , (curr_q_slice ,))
473
+ None if segment_ids_ref is None else segment_ids_ref [curr_q_slice ]
478
474
)
479
- lse = pl .load (lse_ref , (curr_q_slice ,))
480
- do = pl .load (do_scaled_ref , (curr_q_slice , slice (None )), mask = head_mask ,
481
- other = 0.0 )
482
- di = pl .load (delta_ref , (curr_q_slice ,))
475
+ lse = lse_ref [curr_q_slice ]
476
+ do = plgpu .load (do_scaled_ref .at [curr_q_slice , :], mask = head_mask , other = 0.0 )
477
+ di = delta_ref [curr_q_slice ]
483
478
484
479
def inner_loop_dq (start_k , dq ):
485
480
curr_k_slice = pl .dslice (start_k * block_kv_dq , block_kv_dq )
486
- k = pl .load (k_ref , ( curr_k_slice , slice ( None )) , mask = head_mask , other = 0.0 )
487
- v = pl .load (v_ref , ( curr_k_slice , slice ( None )) , mask = head_mask , other = 0.0 )
481
+ k = plgpu .load (k_ref . at [ curr_k_slice , :] , mask = head_mask , other = 0.0 )
482
+ v = plgpu .load (v_ref . at [ curr_k_slice , :] , mask = head_mask , other = 0.0 )
488
483
489
484
qk = pl .dot (q , k .T )
490
485
qk_scale = math .log2 (math .e )
@@ -495,7 +490,7 @@ def inner_loop_dq(start_k, dq):
495
490
if causal or segment_ids_ref is not None :
496
491
mask = None
497
492
if segment_ids_ref is not None :
498
- kv_segment_ids = pl . load ( segment_ids_ref , ( curr_k_slice ,))
493
+ kv_segment_ids = segment_ids_ref [ curr_k_slice ]
499
494
mask = segment_mask (q_segment_ids , kv_segment_ids )
500
495
501
496
if causal :
@@ -523,8 +518,9 @@ def inner_loop_dq(start_k, dq):
523
518
upper_bound = pl .cdiv (kv_seq_len , block_kv_dq )
524
519
525
520
dq = lax .fori_loop (0 , upper_bound , inner_loop_dq , (dq ))
526
- pl .store (dq_ref , (slice (None ), slice (dq .shape [- 1 ])), dq .astype (dq_ref .dtype ),
527
- mask = head_mask )
521
+ plgpu .store (
522
+ dq_ref .at [:, : dq .shape [- 1 ]], dq .astype (dq_ref .dtype ), mask = head_mask
523
+ )
528
524
529
525
530
526
def _mha_backward (sm_scale : float , causal : bool , block_sizes : BlockSizes ,
0 commit comments