Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 106 additions & 73 deletions backends/webgpu/runtime/ops/sdpa/Sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ static WGPUBuffer record_update_cache_dispatch(
uint32_t kv_dst_offset,
uint64_t cache_numel,
uint32_t uc_wg,
bool dynamic_pos,
bool retain_uniform,
const char* label) {
const uint32_t wgc = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(kv_numel), uc_wg, label);
Expand All @@ -274,7 +274,7 @@ static WGPUBuffer record_update_cache_dispatch(
sizeof(uc),
wgc,
uc_wg,
dynamic_pos,
retain_uniform,
"update_cache");
return ubuf;
}
Expand Down Expand Up @@ -436,7 +436,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// Dynamic input_pos: the resize hook rewrites these per step.
WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr,
softmax_buf = nullptr, av_buf = nullptr;
size_t qk_idx = 0;
size_t qk_idx = 0, uc_k_idx = 0, uc_v_idx = 0, softmax_idx = 0, av_idx = 0;

const WGPUDevice device = graph.device();
const uint32_t uc_wg =
Expand All @@ -461,7 +461,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
kv_dst_offset,
numel(k_cache),
uc_wg,
dynamic_pos,
true,
"update_cache(K)");
uc_v_buf = record_update_cache_dispatch(
graph,
Expand All @@ -472,8 +472,10 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
kv_dst_offset,
numel(v_cache),
uc_wg,
dynamic_pos,
true,
"update_cache(V)");
uc_k_idx = graph.num_dispatches() - 2;
uc_v_idx = graph.num_dispatches() - 1;

// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
{
Expand Down Expand Up @@ -501,7 +503,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
qk_wg,
dynamic_pos,
true,
"sdpa_compute_attn_weights");
qk_buf = ubuf;
qk_idx = graph.num_dispatches() - 1;
Expand All @@ -525,9 +527,10 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
0,
dynamic_pos,
true,
"sdpa_softmax");
softmax_buf = ubuf;
softmax_idx = graph.num_dispatches() - 1;
}

// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
Expand All @@ -551,77 +554,107 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
sizeof(p),
wgc,
av_wg,
dynamic_pos,
true,
"sdpa_compute_out");
av_buf = ubuf;
av_idx = graph.num_dispatches() - 1;
}

// Per-step recompute hook; mirrors Vulkan DynamicDispatchNode.
// Per-step recompute: live S (q resize) or input_pos (SymInt); inert if
// static.
const int64_t pos_const = input_pos;
auto sdpa_resize = [q_id,
qn,
out_id,
dynamic_pos,
input_pos_id,
pos_const,
Hq,
Hkv,
D,
Cmax,
g,
scale,
qk_idx,
uc_k_idx,
uc_v_idx,
softmax_idx,
av_idx,
uc_wg,
qk_wg,
av_wg,
uc_k_buf,
uc_v_buf,
qk_buf,
softmax_buf,
av_buf](WebGPUGraph& gr) {
const int64_t s = gr.cur_dims(q_id)[qn - 3];
const int64_t pos = dynamic_pos
? static_cast<int64_t>(gr.read_symint(input_pos_id))
: pos_const;
if (s <= 0 || pos < 0) {
throw std::runtime_error("WebGPU sdpa: invalid live S or input_pos");
}
const int64_t ctx = s + pos;
if (ctx <= 0 || ctx > Cmax) {
throw std::runtime_error(
"WebGPU sdpa: context_len exceeds cache capacity");
}
const uint32_t kv_off = static_cast<uint32_t>(
static_cast<uint64_t>(pos) * static_cast<uint64_t>(Hkv) *
static_cast<uint64_t>(D));
const uint64_t aw_floats = static_cast<uint64_t>(Hq) *
static_cast<uint64_t>(s) * static_cast<uint64_t>(ctx);
if (aw_floats > UINT32_MAX) {
throw std::runtime_error("WebGPU sdpa: Hq*S*context_len exceeds uint32");
}
const uint64_t kv_numel = static_cast<uint64_t>(s) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
const uint64_t k_cache_numel = static_cast<uint64_t>(Cmax) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);

// update_cache K/V: dispatch (kv_numel) + dst offset scale with live S/pos.
UpdateCacheParams uc =
make_update_cache_params(kv_numel, kv_off, k_cache_numel);
wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc));
wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc));
const uint32_t uc_wgc = utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(kv_numel), uc_wg, "uc(resize)");
gr.dispatch_at(uc_k_idx).workgroup_count_x = uc_wgc;
gr.dispatch_at(uc_v_idx).workgroup_count_x = uc_wgc;

// QK: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(ctx/TN).
AttnWeightsParams qp =
make_attn_weights_params(s, Hq, Hkv, D, ctx, pos, g, scale);
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles =
Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(ctx, kSdpaTileN);
gr.dispatch_at(qk_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(qk_tiles), qk_wg, "QK(resize)");

// softmax: one workgroup per (h,s) row.
SoftmaxParams sp = make_softmax_params(Hq, s, ctx);
wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp));
gr.dispatch_at(softmax_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(Hq * s), 1, "softmax(resize)");

// AV: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(D/TN).
ComputeOutParams op = make_compute_out_params(s, Hq, Hkv, D, ctx, g);
wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op));
const int64_t av_tiles =
Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
gr.dispatch_at(av_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
gr.device(), static_cast<uint32_t>(av_tiles), av_wg, "AV(resize)");

// Output attn has the same shape as q: [.., S, Hq, D].
gr.set_cur_dims(out_id, gr.cur_dims(q_id));
};
graph.add_tensor_resize_hook(q_id, sdpa_resize);
if (dynamic_pos) {
graph.add_resize_hook(
input_pos_id,
[input_pos_id,
S,
Hq,
Hkv,
D,
Cmax,
g,
scale,
qk_idx,
qk_wg,
uc_k_buf,
uc_v_buf,
qk_buf,
softmax_buf,
av_buf](WebGPUGraph& gr) {
const int32_t pos = gr.read_symint(input_pos_id);
if (pos < 0) {
throw std::runtime_error(
"WebGPU sdpa: input_pos must be non-negative");
}
const int64_t ctx = S + pos;
if (ctx <= 0 || ctx > Cmax) {
throw std::runtime_error(
"WebGPU sdpa: context_len exceeds cache capacity");
}
const uint32_t kv_off = static_cast<uint32_t>(
static_cast<uint64_t>(pos) * static_cast<uint64_t>(Hkv) *
static_cast<uint64_t>(D));
const uint64_t aw_floats = static_cast<uint64_t>(Hq) *
static_cast<uint64_t>(S) * static_cast<uint64_t>(ctx);
if (aw_floats > UINT32_MAX) {
throw std::runtime_error(
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
}
const uint64_t kv_numel = static_cast<uint64_t>(S) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);
const uint64_t k_cache_numel = static_cast<uint64_t>(Cmax) *
static_cast<uint64_t>(Hkv) * static_cast<uint64_t>(D);

UpdateCacheParams uc =
make_update_cache_params(kv_numel, kv_off, k_cache_numel);
wgpuQueueWriteBuffer(gr.queue(), uc_k_buf, 0, &uc, sizeof(uc));
wgpuQueueWriteBuffer(gr.queue(), uc_v_buf, 0, &uc, sizeof(uc));

AttnWeightsParams qp =
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
utils::div_up(ctx, kSdpaTileN);
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
gr.device(),
static_cast<uint32_t>(qk_tiles),
qk_wg,
"QK(resize)");
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;

SoftmaxParams sp = make_softmax_params(Hq, S, ctx);
wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp));

ComputeOutParams op = make_compute_out_params(S, Hq, Hkv, D, ctx, g);
wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op));
});
graph.add_resize_hook(input_pos_id, sdpa_resize);
}
}

Expand Down
Loading