
Commit 9c76aa5

superbobry authored and Google-ML-Automation committed
[pallas:triton] Migrated example kernels to plgpu.{load,store}
PiperOrigin-RevId: 786856202
1 parent 17e5b94 commit 9c76aa5


6 files changed: +140 -109 lines changed
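The migration is mechanical throughout: explicit index tuples passed to pl.load / pl.store become .at[]-sliced refs passed to plgpu.load / plgpu.store, and unmasked accesses become plain ref indexing. A minimal sketch of the pattern, assuming the same imports the kernels below already use (the kernel and its names are illustrative, not part of this commit):

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu


def copy_block_kernel(x_ref, o_ref, *, block: int, valid_rows: int):
  idx = pl.dslice(0, block)  # rows handled by this program instance
  mask = (jnp.arange(block) < valid_rows)[:, None]  # mask off the padded tail

  # Before: pl.load / pl.store take the ref plus an index tuple.
  #   x = pl.load(x_ref, (idx, slice(None)), mask=mask, other=0.0)
  #   pl.store(o_ref, (idx, slice(None)), x, mask=mask)

  # After: slice the ref with .at[...] and pass it to plgpu.load / plgpu.store.
  x = plgpu.load(x_ref.at[idx, :], mask=mask, other=0.0)
  plgpu.store(o_ref.at[idx, :], x.astype(o_ref.dtype), mask=mask)

  # Unmasked reads need no helper at all and become plain indexing:
  #   k = x_ref[idx, :]  instead of  k = pl.load(x_ref, (idx, slice(None)))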


jax/experimental/pallas/ops/gpu/attention.py

Lines changed: 35 additions & 39 deletions
@@ -108,11 +108,9 @@ def mha_forward_kernel(
   # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim.
   curr_q_slice = pl.dslice(start_q * block_q, block_q)
   head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
-  q = pl.load(q_ref, (slice(None), slice(None)), mask=head_mask, other=0.0)
+  q = plgpu.load(q_ref, mask=head_mask, other=0.0)
   q_segment_ids = (
-      None
-      if segment_ids_ref is None
-      else pl.load(segment_ids_ref, (curr_q_slice,))
+      None if segment_ids_ref is None else segment_ids_ref[curr_q_slice]
   )
   # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
   # (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
@@ -122,7 +120,7 @@ def body(start_k, carry):
     o_prev, m_prev, l_prev = carry
     curr_k_slice = pl.dslice(start_k * block_k, block_k)
 
-    k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
+    k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
     qk = pl.dot(q, k.T)  # [block_q, block_k]
 
     # Scale logits to convert from base-2 to the natural log domain.
@@ -140,7 +138,7 @@ def body(start_k, carry):
     if causal or segment_ids_ref is not None:
       mask = None
       if segment_ids_ref is not None:
-        kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
+        kv_segment_ids = segment_ids_ref[curr_k_slice]
         mask = segment_mask(q_segment_ids, kv_segment_ids)
       if causal:
         span_q = start_q * block_q + jnp.arange(block_q)
@@ -162,7 +160,7 @@ def body(start_k, carry):
     l_curr = s_curr.sum(axis=-1)
     l_next = l_prev_corr + l_curr
     o_prev_corr = correction[:, None] * o_prev
-    v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask)
+    v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask)
     o_curr = pl.dot(s_curr.astype(v.dtype), v)
 
     o_next = o_prev_corr + o_curr
@@ -183,8 +181,7 @@ def body(start_k, carry):
     lse_ref = residual_refs[0]
     lse_ref[...] = m_i + jnp.log2(l_i)
   # Write output to dram.
-  pl.store(o_ref, (slice(None), slice(o.shape[-1])), o.astype(o_ref.dtype),
-           mask=head_mask)
+  plgpu.store(o_ref.at[:, : o.shape[-1]], o.astype(o_ref.dtype), mask=head_mask)
 
 def segment_mask(
     q_segment_ids: jax.Array,
@@ -328,8 +325,8 @@ def _mha_forward(
 def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
   # load
   head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :]
-  o = pl.load(out_ref, (slice(None), slice(None)), mask=head_mask, other=0.0)
-  do = pl.load(dout_ref, (slice(None), slice(None)), mask=head_mask, other=0.0)
+  o = plgpu.load(out_ref, mask=head_mask, other=0.0)
+  do = plgpu.load(dout_ref, mask=head_mask, other=0.0)
   # compute
   delta = jnp.sum(o * do, axis=1)
   # write-back
@@ -402,20 +399,18 @@ def mha_backward_kernel(
   dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)
 
   head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
-  v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
-  k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
+  v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
+  k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
   span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv)
   kv_segment_ids = (
-      None
-      if segment_ids_ref is None
-      else pl.load(segment_ids_ref, (curr_k_slice,))
+      None if segment_ids_ref is None else segment_ids_ref[curr_k_slice]
   )
 
   def inner_loop_dkdv(start_q, carry):
     dv, dk = carry
     curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv)
 
-    q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0)
+    q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
     qk = pl.dot(q, k.T)
     qk_scale = math.log2(math.e)
     if sm_scale != 1.:
@@ -425,7 +420,7 @@ def inner_loop_dkdv(start_q, carry):
     if causal or segment_ids_ref is not None:
       mask = None
       if segment_ids_ref is not None:
-        q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,))
+        q_segment_ids = segment_ids_ref[curr_q_slice]
         mask = segment_mask(q_segment_ids, kv_segment_ids)
 
       if causal:
@@ -436,10 +431,11 @@ def inner_loop_dkdv(start_q, carry):
         )
       qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
 
-    lse = pl.load(lse_ref, (curr_q_slice,))
-    di = pl.load(delta_ref, (curr_q_slice,))
-    do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask,
-                 other=0.0)
+    lse = lse_ref[curr_q_slice]
+    di = delta_ref[curr_q_slice]
+    do = plgpu.load(
+        do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0
+    )
 
     p = jnp.exp2(qk - lse[:, None])
     dv = dv + pl.dot(p.astype(do.dtype).T, do)
@@ -456,10 +452,12 @@ def inner_loop_dkdv(start_q, carry):
   dv, dk = lax.fori_loop(
       lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk)
   )
-  pl.store(dv_ref, (slice(None), slice(dv.shape[-1])), dv.astype(dv_ref.dtype),
-           mask=head_mask)
-  pl.store(dk_ref, (slice(None), slice(dk.shape[-1])), dk.astype(dk_ref.dtype),
-           mask=head_mask)
+  plgpu.store(
+      dv_ref.at[:, : dv.shape[-1]], dv.astype(dv_ref.dtype), mask=head_mask
+  )
+  plgpu.store(
+      dk_ref.at[:, : dk.shape[-1]], dk.astype(dk_ref.dtype), mask=head_mask
+  )
 
   # Scan #2: dQ
   # 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM.
@@ -470,21 +468,18 @@ def inner_loop_dkdv(start_q, carry):
   span_q = start_q * block_q_dq + jnp.arange(block_q_dq)
   dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32)
 
-  q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0)
+  q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
   q_segment_ids = (
-      None
-      if segment_ids_ref is None
-      else pl.load(segment_ids_ref, (curr_q_slice,))
+      None if segment_ids_ref is None else segment_ids_ref[curr_q_slice]
   )
-  lse = pl.load(lse_ref, (curr_q_slice,))
-  do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask,
-               other=0.0)
-  di = pl.load(delta_ref, (curr_q_slice,))
+  lse = lse_ref[curr_q_slice]
+  do = plgpu.load(do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
+  di = delta_ref[curr_q_slice]
 
   def inner_loop_dq(start_k, dq):
     curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq)
-    k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
-    v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0)
+    k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
+    v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
 
     qk = pl.dot(q, k.T)
     qk_scale = math.log2(math.e)
@@ -495,7 +490,7 @@ def inner_loop_dq(start_k, dq):
     if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
-        kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
+        kv_segment_ids = segment_ids_ref[curr_k_slice]
        mask = segment_mask(q_segment_ids, kv_segment_ids)
 
      if causal:
@@ -523,8 +518,9 @@ def inner_loop_dq(start_k, dq):
     upper_bound = pl.cdiv(kv_seq_len, block_kv_dq)
 
   dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
-  pl.store(dq_ref, (slice(None), slice(dq.shape[-1])), dq.astype(dq_ref.dtype),
-           mask=head_mask)
+  plgpu.store(
+      dq_ref.at[:, : dq.shape[-1]], dq.astype(dq_ref.dtype), mask=head_mask
+  )
 
 
 def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,

jax/experimental/pallas/ops/gpu/decode_attention.py

Lines changed: 8 additions & 8 deletions
@@ -50,7 +50,7 @@ def _compute(start_idx, kv_seq_len, o, m_i, l_i):
     # Load q: it will stay in L1 throughout. Indices form a matrix because we
     # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
     # q tile has shape [block_h, head_dim].
-    q = pl.load(q_ref, (q_slice, pl.ds(None)), mask=q_mask)
+    q = plgpu.load(q_ref.at[q_slice, :], mask=q_mask)
 
     def _dot(a, b):
       # if a.shape[0] == 1:
@@ -66,7 +66,7 @@ def body(start_k, carry):
       o_prev, m_prev, l_prev = carry
       curr_k_slice = pl.ds(start_k * block_k, block_k)
 
-      k = pl.load(k_ref, (curr_k_slice, slice(None)))
+      k = k_ref[curr_k_slice, :]
       qk = _dot(q, k.T)  # [block_h, block_k]
       if sm_scale != 1.0:
         qk *= sm_scale  # [block_h, block_k]
@@ -86,7 +86,7 @@ def body(start_k, carry):
       )  # Use m_next instead of m_curr to avoid a correction on l_curr
       l_curr = s_curr.sum(axis=-1)
       l_next = l_prev_corr + l_curr
-      v = pl.load(v_ref, (curr_k_slice, slice(None)))
+      v = v_ref[curr_k_slice, :]
       o_curr = _dot(s_curr.astype(v.dtype), v)
 
       # flash2 unscaled_o
@@ -106,10 +106,10 @@ def body(start_k, carry):
 
   start_idx = split_k_seq_len * prog_j
   if start_idx_ref is not None:
-    start_idx = jnp.maximum(start_idx, pl.load(start_idx_ref, ()))
+    start_idx = jnp.maximum(start_idx, start_idx_ref[()])
   kv_seq_len = (prog_j + 1) * split_k_seq_len  # lower bound on actual k_seq_len
   if kv_seq_len_ref is not None:
-    kv_seq_len = jnp.minimum(kv_seq_len, pl.load(kv_seq_len_ref, ()))
+    kv_seq_len = jnp.minimum(kv_seq_len, kv_seq_len_ref[()])
 
   if start_idx_ref is None and kv_seq_len is None:
     o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i)
@@ -122,10 +122,10 @@ def body(start_k, carry):
   if residual_refs:
     l_ref, m_ref = residual_refs
     vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None
-    pl.store(l_ref, q_slice, l_i, mask=vec_q_mask)
-    pl.store(m_ref, q_slice, m_i, mask=vec_q_mask)
+    plgpu.store(l_ref.at[q_slice], l_i, mask=vec_q_mask)
+    plgpu.store(m_ref.at[q_slice], m_i, mask=vec_q_mask)
   o = o.astype(o_ref.dtype)
-  pl.store(o_ref, (q_slice, pl.ds(None)), o, mask=q_mask)
+  plgpu.store(o_ref.at[q_slice, :], o, mask=q_mask)
 
 
 def decode_attn_unbatched(
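decode_attention also reads optional 0-d scalar refs for the split bounds; with no mask involved, pl.load(ref, ()) becomes plain indexing with an empty tuple. A small standalone sketch that mirrors the hunk above (in the kernel this logic lives inline; the helper function here is only for illustration):

import jax.numpy as jnp


def clamp_split_bounds(start_idx_ref, kv_seq_len_ref, *,
                       split_k_seq_len: int, prog_j: int):
  # 0-d refs are read with an empty-tuple index (was: pl.load(start_idx_ref, ())).
  start_idx = jnp.maximum(split_k_seq_len * prog_j, start_idx_ref[()])
  # Same for the sequence-length bound (was: pl.load(kv_seq_len_ref, ())).
  kv_seq_len = jnp.minimum((prog_j + 1) * split_k_seq_len, kv_seq_len_ref[()])
  return start_idx, kv_seq_len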

jax/experimental/pallas/ops/gpu/layer_norm.py

Lines changed: 48 additions & 29 deletions
@@ -35,17 +35,19 @@ def layer_norm_forward_kernel(
   def mean_body(i, acc_ref):
     col_idx = i * block_size + jnp.arange(block_size)
     mask = col_idx < n_col
-    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
-                eviction_policy="evict_last").astype(jnp.float32)
+    a = plgpu.load(
+        x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
+    ).astype(jnp.float32)
     acc_ref[:] += a
   mean = for_loop(pl.cdiv(n_col, block_size), mean_body,
                   jnp.zeros(block_size)).sum() / n_col
 
   def var_body(i, acc_ref):
     col_idx = i * block_size + jnp.arange(block_size)
     mask = col_idx < n_col
-    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
-                eviction_policy="evict_last").astype(jnp.float32)
+    a = plgpu.load(
+        x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
+    ).astype(jnp.float32)
     a = jnp.where(mask, a - mean, 0.)
     acc_ref[:] += a * a
   var = for_loop(pl.cdiv(n_col, block_size), var_body,
@@ -59,12 +61,13 @@ def var_body(i, acc_ref):
   def body(i, _):
     col_idx = i * block_size + jnp.arange(block_size)
     mask = col_idx < n_col
-    weight = pl.load(weight_ref, (col_idx,), mask=mask)
-    bias = pl.load(bias_ref, (col_idx,), mask=mask)
-    x = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
-                eviction_policy="evict_first").astype(jnp.float32)
+    weight = plgpu.load(weight_ref.at[col_idx], mask=mask)
+    bias = plgpu.load(bias_ref.at[col_idx], mask=mask)
+    x = plgpu.load(
+        x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first"
+    ).astype(jnp.float32)
     out = (x - mean) * rstd * weight + bias
-    pl.store(o_ref, (col_idx,), out.astype(o_ref.dtype), mask=mask)
+    plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask)
   for_loop(pl.cdiv(n_col, block_size), body, ())
 
 
@@ -119,12 +122,18 @@ def layer_norm_backward_kernel_dx(
   def mean_body(i, acc_ref):
     col_idx = i * block_size + jnp.arange(block_size)
     mask = col_idx < n_col
-    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
-                eviction_policy="evict_last").astype(jnp.float32)
-    dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
-                   eviction_policy="evict_last").astype(jnp.float32)
-    weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
-                     eviction_policy="evict_last").astype(jnp.float32)
+    a = plgpu.load(
+        x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
+    ).astype(jnp.float32)
+    dout = plgpu.load(
+        do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
+    ).astype(jnp.float32)
+    weight = plgpu.load(
+        weight_ref.at[col_idx],
+        mask=mask,
+        other=0.0,
+        eviction_policy="evict_last",
+    ).astype(jnp.float32)
     a_hat = (a - mean_ref[...]) * rstd_ref[...]
     wdout = weight * dout
     mean1_acc_ref, mean2_acc_ref = acc_ref
@@ -139,12 +148,18 @@ def mean_body(i, acc_ref):
   def dx_body(i, acc_ref):
     col_idx = i * block_size + jnp.arange(block_size)
     mask = col_idx < n_col
-    a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
-                eviction_policy="evict_last").astype(jnp.float32)
-    dout = pl.load(do_ref, (col_idx,), mask=mask, other=0.,
-                   eviction_policy="evict_last").astype(jnp.float32)
-    weight = pl.load(weight_ref, (col_idx,), mask=mask, other=0.,
-                     eviction_policy="evict_last").astype(jnp.float32)
+    a = plgpu.load(
+        x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
+    ).astype(jnp.float32)
+    dout = plgpu.load(
+        do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
+    ).astype(jnp.float32)
+    weight = plgpu.load(
+        weight_ref.at[col_idx],
+        mask=mask,
+        other=0.0,
+        eviction_policy="evict_last",
+    ).astype(jnp.float32)
     a_hat = (a - mean_ref[...]) * rstd_ref[...]
     wdout = weight * dout
     da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...]
@@ -168,21 +183,25 @@ def body(i, acc_ref):
     row_idx = i * block_m + jnp.arange(block_m)
     row_mask = row_idx < m
     mask = row_mask[:, None] & col_mask[None, :]
-    a = pl.load(
-        x_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
+    a = plgpu.load(
+        x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
     ).astype(jnp.float32)
-    dout = pl.load(
-        do_ref, (row_idx[:, None], col_idx[None]), mask=mask, other=0.0
+    dout = plgpu.load(
+        do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
     ).astype(jnp.float32)
-    mean = pl.load(mean_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
-    rstd = pl.load(rstd_ref, (row_idx,), mask=row_mask, other=0.).astype(jnp.float32)
+    mean = plgpu.load(mean_ref.at[row_idx], mask=row_mask, other=0.0).astype(
+        jnp.float32
+    )
+    rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype(
+        jnp.float32
+    )
     a_hat = (a - mean[:, None]) * rstd[:, None]
     dw_acc_ref, db_acc_ref = acc_ref
     dw_acc_ref[:] += (dout * a_hat).sum(axis=0)
     db_acc_ref[:] += dout.sum(axis=0)
   dw_acc, db_acc = for_loop(pl.cdiv(m, block_m), body, (jnp.zeros(block_n), jnp.zeros(block_n)))
-  pl.store(dw_ref, (col_idx,), dw_acc.astype(dw_ref.dtype), mask=col_mask)
-  pl.store(db_ref, (col_idx,), db_acc.astype(db_ref.dtype), mask=col_mask)
+  plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask)
+  plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask)
 
 
 def layer_norm_backward(
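The layer-norm kernels also pass cache hints; the eviction_policy keyword carries over to plgpu.load unchanged, as does chaining .astype on the loaded value. A minimal sketch under the same assumptions as the first example (the kernel and its names are illustrative; o_ref is taken to be a 0-d output ref):

import jax.numpy as jnp
from jax.experimental.pallas import triton as plgpu


def masked_colsum_kernel(x_ref, o_ref, *, block_size: int, n_col: int):
  col_idx = jnp.arange(block_size)
  mask = col_idx < n_col
  # Before:
  #   a = pl.load(x_ref, (col_idx,), mask=mask, other=0.,
  #               eviction_policy="evict_last").astype(jnp.float32)
  # After: the cache hint moves over as-is.
  a = plgpu.load(
      x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
  ).astype(jnp.float32)
  o_ref[...] = a.sum()  # o_ref assumed to be a scalar output ref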

jax/experimental/pallas/ops/gpu/paged_attention.py

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,7 @@ def paged_attention_kernel(
 
   def _compute(start_page_idx, end_page_idx, o, m_i, l_i):
     q_slice = pl.ds(0, block_h)
-    q = pl.load(q_ref, (q_slice, slice(None)))
+    q = q_ref[q_slice, :]
 
     # Loop over blocks of pages to process a entire page sequence partition.
     # Grid loops over q blocks over num_heads.
@@ -64,7 +64,7 @@ def body(start_k, carry):
       block_tables_slice = pl.ds(
           start_k * pages_per_compute_block, pages_per_compute_block
       )
-      block_tables = pl.load(block_tables_ref, block_tables_slice)
+      block_tables = block_tables_ref[block_tables_slice]
       k = k_pages_ref[block_tables].reshape(block_k, head_dim)
       v = v_pages_ref[block_tables].reshape(block_k, head_dim)
       if k_scales_pages_ref is not None:
