diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 7e3cb32de05..527e9cb57a8 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -161,6 +161,9 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) add_webgpu_native_test( webgpu_update_cache_test test/native/test_update_cache.cpp ) + add_webgpu_native_test( + webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp + ) # Manifest-driven op-test framework: a generic gtest driver (webgpu_op_test) + # its device-free util unit test. GTest needs -DEXECUTORCH_BUILD_TESTS=ON. diff --git a/backends/webgpu/test/native/test_dynamic_shape.cpp b/backends/webgpu/test/native/test_dynamic_shape.cpp new file mode 100644 index 00000000000..f97b87a00fe --- /dev/null +++ b/backends/webgpu/test/native/test_dynamic_shape.cpp @@ -0,0 +1,544 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Native test for dynamic tensor shapes (Option 2). One graph is built at the +// upper-bound seq-len MAXS and run at several live S; the output must match the +// torch golden at each S (allocate-at-max + per-op resize hooks + output-EValue +// resize). Cases: +// A dyn_rms at S=MAXS -> golden (static-equivalent) +// B dyn_rms at S < MAXS (64, 8, 1) -> golden (resize shrinks +// dispatch) C ONE loaded graph reused across S -> all golden (buffers +// never moved +// => bind groups stayed valid) +// D static_rms (no dynamic dim) -> golden (static path +// unchanged) F dyn_rms_chain (rms(rms(x))) at 3 S -> golden (resize +// CASCADE, DD-4) +// G rms+residual H rms*x I dyn_linear J sdpa_dyn K emb_dyn L rope_dyn +// M dyn_sigmoid N dyn_select (select_copy(0,-1), dynamic S) +// .pte + goldens from test/ops/dynamic_shape/test_dynamic_shape_export.py. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::webgpu; +using namespace executorch::extension; +using namespace executorch::runtime; + +namespace { + +constexpr int kHidden = 64; + +std::vector read_bin(const std::string& path) { + std::ifstream f(path, std::ios::binary | std::ios::ate); + if (!f) { + return {}; + } + const std::streamsize n = f.tellg(); + if (n < 0) { + return {}; + } + f.seekg(0); + std::vector v(static_cast(n) / sizeof(float)); + f.read(reinterpret_cast(v.data()), n); + return v; +} + +float max_err(const std::vector& a, const std::vector& b) { + if (a.size() != b.size() || a.empty()) { + return 1e30f; + } + float m = 0.0f; + for (size_t i = 0; i < a.size(); i++) { + m = std::fmax(m, std::fabs(a[i] - b[i])); + } + return m; +} + +// Run one forward of a [1,1,S,kHidden] input through `m`; return the output. +std::vector +run_s(Module& m, const std::string& dir, const std::string& prefix, int s) { + auto input = + read_bin(dir + "/" + prefix + ".S" + std::to_string(s) + ".input.bin"); + if (input.empty()) { + printf(" MISSING input %s.S%d\n", prefix.c_str(), s); + return {}; + } + if (input.size() != static_cast(s) * kHidden) { + printf(" WRONG input size %s.S%d\n", prefix.c_str(), s); + return {}; + } + auto t = make_tensor_ptr({1, 1, s, kHidden}, std::move(input)); + auto r = m.forward({EValue(t)}); + if (!r.ok() || r.get().empty() || !r.get()[0].isTensor()) { + printf(" forward FAILED (S=%d, err=%d)\n", s, r.ok() ? 0 : (int)r.error()); + return {}; + } + const auto& out = r.get()[0].toTensor(); + const float* d = out.const_data_ptr(); + const size_t numel = static_cast(s) * kHidden; + // Output EValue must have been resized to the live shape. + if (out.numel() != static_cast(numel)) { + printf( + " WRONG output numel: got %zd want %zu (S=%d)\n", + (ssize_t)out.numel(), + numel, + s); + return {}; + } + return std::vector(d, d + numel); +} + +bool check_s( + Module& m, + const std::string& dir, + const std::string& prefix, + int s, + bool& ok) { + auto got = run_s(m, dir, prefix, s); + auto golden = + read_bin(dir + "/" + prefix + ".S" + std::to_string(s) + ".golden.bin"); + float e = max_err(got, golden); + bool pass = !got.empty() && e < 1e-3f; + printf( + " %s S=%-3d max_err=%e -> %s\n", + prefix.c_str(), + s, + e, + pass ? "PASS" : "FAIL"); + if (!pass) { + printf(" got.size=%zu golden.size=%zu\n", got.size(), golden.size()); + for (size_t i = 0; i < 4 && i < got.size() && i < golden.size(); i++) { + printf(" [%zu] got=%.6f golden=%.6f\n", i, got[i], golden[i]); + } + } + ok = ok && pass; + return pass; +} + +// Dynamic quantized linear: input [M, lin_k] -> output [M, lin_n]. +constexpr int kLinK = 64; +constexpr int kLinN = 128; +void check_linear(const std::string& dir, int m_rows, bool& ok) { + Module m(dir + "/dyn_linear.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_linear.pte\n"); + ok = false; + return; + } + auto input = + read_bin(dir + "/dyn_linear.S" + std::to_string(m_rows) + ".input.bin"); + auto golden = + read_bin(dir + "/dyn_linear.S" + std::to_string(m_rows) + ".golden.bin"); + if (input.empty()) { + printf(" MISSING dyn_linear.S%d\n", m_rows); + ok = false; + return; + } + auto t = make_tensor_ptr({m_rows, kLinK}, std::move(input)); + auto r = m.forward({EValue(t)}); + if (!r.ok() || r.get().empty() || !r.get()[0].isTensor()) { + printf(" linear M=%d forward FAILED\n", m_rows); + ok = false; + return; + } + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(m_rows) * kLinN; + std::vector got( + out.const_data_ptr(), out.const_data_ptr() + numel); + float e = max_err(got, golden); + // 4-bit quant: looser tol (the kernel mirrors the dequant-matmul reference). + bool pass = out.numel() == static_cast(numel) && e < 5e-3f; + printf( + " dyn_linear M=%-3d max_err=%e -> %s\n", + m_rows, + e, + pass ? "PASS" : "FAIL"); + ok = ok && pass; +} + +// Dynamic SDPA (GQA prefill, input_pos=0): q[1,s,hq,d] k/v[1,s,hkv,d] +// caches[1,cmax,hkv,d]; attn output [1,s,hq,d] selected by shape (3 outputs). +constexpr int kSdHq = 8, kSdHkv = 2, kSdD = 16, kSdCmax = 64; +void check_sdpa(const std::string& dir, int s, bool& ok) { + Module m(dir + "/sdpa_dyn.pte"); + Error le = m.load_forward(); + if (le == Error::DelegateInvalidCompatibility) { + // PENDING op coverage: dynamic-S SDPA build throws err 48 until registered. + printf(" PENDING sdpa_dyn S=%d (op coverage, err %d)\n", s, (int)le); + return; + } + if (le != Error::Ok) { + printf(" sdpa_dyn S=%d load FAILED (err %d)\n", s, (int)le); + ok = false; + return; + } + const std::string b = dir + "/sdpa_dyn.S" + std::to_string(s) + "."; + auto q = read_bin(b + "q.bin"); + auto k = read_bin(b + "k.bin"); + auto v = read_bin(b + "v.bin"); + auto kc = read_bin(b + "kc.bin"); + auto vc = read_bin(b + "vc.bin"); + auto golden = read_bin(b + "golden.bin"); + if (q.empty() || golden.empty()) { + printf(" MISSING sdpa_dyn.S%d\n", s); + ok = false; + return; + } + auto tq = make_tensor_ptr({1, s, kSdHq, kSdD}, std::move(q)); + auto tk = make_tensor_ptr({1, s, kSdHkv, kSdD}, std::move(k)); + auto tv = make_tensor_ptr({1, s, kSdHkv, kSdD}, std::move(v)); + auto tkc = make_tensor_ptr({1, kSdCmax, kSdHkv, kSdD}, std::move(kc)); + auto tvc = make_tensor_ptr({1, kSdCmax, kSdHkv, kSdD}, std::move(vc)); + auto r = + m.forward({EValue(tq), EValue(tk), EValue(tv), EValue(tkc), EValue(tvc)}); + if (!r.ok()) { + printf(" sdpa S=%d forward FAILED (err=%d)\n", s, (int)r.error()); + ok = false; + return; + } + // Select the attn output by full shape [1,s,hq,d] (never numel). + const float* attn = nullptr; + size_t numel = static_cast(s) * kSdHq * kSdD; + for (size_t i = 0; i < r.get().size(); i++) { + if (!r.get()[i].isTensor()) { + continue; + } + const auto& t = r.get()[i].toTensor(); + if (t.dim() == 4 && t.size(1) == s && t.size(2) == kSdHq && + t.size(3) == kSdD) { + attn = t.const_data_ptr(); + break; + } + } + if (attn == nullptr) { + printf( + " sdpa S=%d: no attn output of shape [1,%d,%d,%d]\n", + s, + s, + kSdHq, + kSdD); + ok = false; + return; + } + std::vector got(attn, attn + numel); + float e = max_err(got, golden); + bool pass = e < 2e-3f; // SDPA tol (abs 1e-4 / rel 1e-3 family) + printf(" sdpa_dyn S=%-3d max_err=%e -> %s\n", s, e, pass ? "PASS" : "FAIL"); + ok = ok && pass; +} + +// Dynamic embedding: int64 token ids [N] -> [N, kEmbDim] fp32. The int64 host +// input exercises copy_inputs' int64->int32 narrow path under dynamic shapes. +constexpr int kEmbDim = 64; +void check_embedding(const std::string& dir, int n, bool& ok) { + Module m(dir + "/emb_dyn.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load emb_dyn.pte\n"); + ok = false; + return; + } + const std::string b = dir + "/emb_dyn.S" + std::to_string(n) + "."; + std::ifstream f(b + "idx.bin", std::ios::binary | std::ios::ate); + if (!f) { + printf(" MISSING emb_dyn.S%d\n", n); + ok = false; + return; + } + const std::streamsize nb = f.tellg(); + if (nb < 0) { + printf(" MISSING emb_dyn.S%d\n", n); + ok = false; + return; + } + f.seekg(0); + std::vector idx(static_cast(nb) / sizeof(int64_t)); + f.read(reinterpret_cast(idx.data()), nb); + if (idx.size() != static_cast(n)) { + printf(" WRONG emb_dyn idx size S%d\n", n); + ok = false; + return; + } + auto golden = read_bin(b + "golden.bin"); + auto t = make_tensor_ptr({n}, std::move(idx)); // int64 (Long) host input + auto r = m.forward({EValue(t)}); + if (!r.ok() || r.get().empty() || !r.get()[0].isTensor()) { + printf( + " emb N=%d forward FAILED (err=%d)\n", n, r.ok() ? 0 : (int)r.error()); + ok = false; + return; + } + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(n) * kEmbDim; + std::vector got( + out.const_data_ptr(), out.const_data_ptr() + numel); + float e = max_err(got, golden); + bool pass = out.numel() == static_cast(numel) && e < 5e-3f; + printf(" emb_dyn N=%-3d max_err=%e -> %s\n", n, e, pass ? "PASS" : "FAIL"); + ok = ok && pass; +} + +// Dynamic RoPE: xq[1,s,nh,hd] xk[1,s,nkv,hd] freqs[s,hd/2] -> xq_out/xk_out +// (2 outputs, selected by head count nh != nkv). +constexpr int kRopeNH = 8, kRopeNKV = 2, kRopeHD = 64; +void check_rope(const std::string& dir, int s, bool& ok) { + Module m(dir + "/rope_dyn.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load rope_dyn.pte\n"); + ok = false; + return; + } + const std::string b = dir + "/rope_dyn.S" + std::to_string(s) + "."; + auto xq = read_bin(b + "xq.bin"); + auto xk = read_bin(b + "xk.bin"); + auto fc = read_bin(b + "fc.bin"); + auto fs = read_bin(b + "fs.bin"); + auto gq = read_bin(b + "gq.bin"); + auto gk = read_bin(b + "gk.bin"); + if (xq.empty() || gq.empty()) { + printf(" MISSING rope_dyn.S%d\n", s); + ok = false; + return; + } + auto txq = make_tensor_ptr({1, s, kRopeNH, kRopeHD}, std::move(xq)); + auto txk = make_tensor_ptr({1, s, kRopeNKV, kRopeHD}, std::move(xk)); + auto tfc = make_tensor_ptr({s, kRopeHD / 2}, std::move(fc)); + auto tfs = make_tensor_ptr({s, kRopeHD / 2}, std::move(fs)); + auto r = m.forward({EValue(txq), EValue(txk), EValue(tfc), EValue(tfs)}); + if (!r.ok()) { + printf(" rope S=%d forward FAILED (err=%d)\n", s, (int)r.error()); + ok = false; + return; + } + // Select xq_out (nh heads) and xk_out (nkv heads) by shape. + const float *oq = nullptr, *okp = nullptr; + for (size_t i = 0; i < r.get().size(); i++) { + if (!r.get()[i].isTensor()) { + continue; + } + const auto& t = r.get()[i].toTensor(); + if (t.dim() == 4 && t.size(1) == s && t.size(3) == kRopeHD) { + if (t.size(2) == kRopeNH) { + oq = t.const_data_ptr(); + } else if (t.size(2) == kRopeNKV) { + okp = t.const_data_ptr(); + } + } + } + if (oq == nullptr || okp == nullptr) { + printf(" rope S=%d: missing xq_out/xk_out by shape\n", s); + ok = false; + return; + } + std::vector gotq(oq, oq + static_cast(s) * kRopeNH * kRopeHD); + std::vector gotk( + okp, okp + static_cast(s) * kRopeNKV * kRopeHD); + float e = std::fmax(max_err(gotq, gq), max_err(gotk, gk)); + bool pass = e < 1e-3f; + printf(" rope_dyn S=%-3d max_err=%e -> %s\n", s, e, pass ? "PASS" : "FAIL"); + ok = ok && pass; +} + +// Dynamic select_copy(0,-1): input [2,1,S,kHidden] -> output [1,S,kHidden]. The +// negative index resolves against the (static) leading dim live; the dynamic S +// flows to the output, so the resize hook recomputes its dispatch each S. +constexpr int kSelLead = 2; +void check_select(const std::string& dir, int s, bool& ok) { + Module m(dir + "/dyn_select.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_select.pte\n"); + ok = false; + return; + } + auto input = + read_bin(dir + "/dyn_select.S" + std::to_string(s) + ".input.bin"); + auto golden = + read_bin(dir + "/dyn_select.S" + std::to_string(s) + ".golden.bin"); + if (input.empty() || golden.empty()) { + printf(" MISSING dyn_select.S%d\n", s); + ok = false; + return; + } + auto t = make_tensor_ptr({kSelLead, 1, s, kHidden}, std::move(input)); + auto r = m.forward({EValue(t)}); + if (!r.ok() || r.get().empty() || !r.get()[0].isTensor()) { + printf( + " select S=%d forward FAILED (err=%d)\n", + s, + r.ok() ? 0 : (int)r.error()); + ok = false; + return; + } + const auto& out = r.get()[0].toTensor(); + const size_t numel = static_cast(s) * kHidden; + std::vector got( + out.const_data_ptr(), out.const_data_ptr() + numel); + float e = max_err(got, golden); + bool pass = out.numel() == static_cast(numel) && e < 1e-3f; + printf( + " dyn_select S=%-3d max_err=%e -> %s\n", s, e, pass ? "PASS" : "FAIL"); + ok = ok && pass; +} + +} // namespace + +int main(int argc, char** argv) { + std::string dir = "/tmp/dynamic_shape"; + if (argc > 1) { + dir = argv[1]; + } + if (const char* env = std::getenv("WEBGPU_DYNAMIC_SHAPE_DIR")) { + dir = env; + } + + WebGPUContext ctx; + try { + ctx = create_webgpu_context(); + } catch (const std::exception& e) { + printf("SKIP: %s\n", e.what()); + return 0; + } + set_default_webgpu_context(&ctx); + printf("WebGPU device acquired (native); dir: %s\n", dir.c_str()); + + bool ok = true; + + // Cases A + B: single dynamic rms_norm at S = MAXS .. 1 (fresh module each). + printf("\n--- A/B: dynamic rms_norm at several S (fresh load each) ---\n"); + for (int s : {128, 64, 8, 1}) { + Module m(dir + "/dyn_rms.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_rms.pte\n"); + ok = false; + break; + } + check_s(m, dir, "dyn_rms", s, ok); + } + + // Case C: ONE loaded graph reused across S (buffers must not move). + printf("\n--- C: one graph reused across S (bind groups stay valid) ---\n"); + { + Module m(dir + "/dyn_rms.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_rms.pte\n"); + ok = false; + } else { + for (int s : {128, 1, 64, 8, 128}) { + check_s(m, dir, "dyn_rms", s, ok); + } + } + } + + // Case D: static rms_norm (no dynamic dim) — regression. + printf("\n--- D: static rms_norm (static path unchanged) ---\n"); + { + Module m(dir + "/static_rms.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load static_rms.pte\n"); + ok = false; + } else { + check_s(m, dir, "static_rms", 8, ok); + } + } + + // Case F: 2-op chain rms(rms(x)) — resize cascade. + printf("\n--- F: rms(rms(x)) cascade at several S ---\n"); + for (int s : {128, 16, 1}) { + Module m(dir + "/dyn_rms_chain.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_rms_chain.pte\n"); + ok = false; + break; + } + check_s(m, dir, "dyn_rms_chain", s, ok); + } + + // Case G: rms(x)+x residual — cross-op (rms -> add) cascade. + printf("\n--- G: rms(x)+x residual (rms->add cascade) at several S ---\n"); + for (int s : {128, 32, 1}) { + Module m(dir + "/dyn_residual.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_residual.pte\n"); + ok = false; + break; + } + check_s(m, dir, "dyn_residual", s, ok); + } + + // Case H: rms(x)*x — exercises the mul op resize. + printf("\n--- H: rms(x)*x (mul op) at several S ---\n"); + for (int s : {128, 32, 1}) { + Module m(dir + "/dyn_rmsmul.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_rmsmul.pte\n"); + ok = false; + break; + } + check_s(m, dir, "dyn_rmsmul", s, ok); + } + + // Case I: dynamic 4-bit quantized linear (prefill GEMM) at several M. + printf("\n--- I: dynamic linear_q4gsw [M,64]->[M,128] at several M ---\n"); + for (int mrows : {128, 32, 1}) { + check_linear(dir, mrows, ok); + } + + // Case J: dynamic SDPA (GQA prefill) at several seq-len S. + printf("\n--- J: dynamic sdpa_with_kv_cache (prefill) at several S ---\n"); + for (int s : {64, 16, 1}) { + check_sdpa(dir, s, ok); + } + + // Case K: dynamic embedding (int64 token ids) at several token counts. + printf("\n--- K: dynamic embedding_q4gsw (int64 ids) at several N ---\n"); + for (int n : {16, 8, 1}) { + check_embedding(dir, n, ok); + } + + // Case L: dynamic RoPE (two outputs) at several seq-len S. + printf("\n--- L: dynamic apply_rotary_emb at several S ---\n"); + for (int s : {16, 8, 1}) { + check_rope(dir, s, ok); + } + + // Case M: dynamic sigmoid (elementwise) at several S. + printf("\n--- M: dynamic sigmoid at several S ---\n"); + for (int s : {128, 32, 1}) { + Module m(dir + "/dyn_sigmoid.pte"); + if (m.load_forward() != Error::Ok) { + printf(" FAIL load dyn_sigmoid.pte\n"); + ok = false; + break; + } + check_s(m, dir, "dyn_sigmoid", s, ok); + } + + // Case N: dynamic select_copy(0,-1) at several S. + printf("\n--- N: dynamic select_copy(0,-1) at several S ---\n"); + for (int s : {128, 32, 1}) { + check_select(dir, s, ok); + } + + set_default_webgpu_context(nullptr); + destroy_webgpu_context(ctx); + + if (!ok) { + printf("\ndynamic_shape tests FAILED\n"); + return 1; + } + printf("\nAll dynamic_shape tests passed\n"); + return 0; +} diff --git a/backends/webgpu/test/ops/dynamic_shape/test_dynamic_shape_export.py b/backends/webgpu/test/ops/dynamic_shape/test_dynamic_shape_export.py new file mode 100644 index 00000000000..6652d073805 --- /dev/null +++ b/backends/webgpu/test/ops/dynamic_shape/test_dynamic_shape_export.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Dynamic tensor-shape (Option 2) export tests via VulkanPartitioner. + +Exports ONE graph built at the upper-bound seq-len MAXS that the native runtime +(`test/native/test_dynamic_shape.cpp`) then runs at several live S, asserting the +output matches the torch golden and that the static path is unchanged. Numerics +are checked in the native test; this verifies the dynamic export side + writes +goldens. +""" + +import os +import unittest + +import torch +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.exir import to_edge_transform_and_lower + +MAXS = 128 # upper bound for the dynamic seq-len dim (within the 1D dispatch cap) +HIDDEN = 64 + + +def _rms(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + x_f32 = x.to(torch.float32) + var = x_f32.pow(2).mean(dim=-1, keepdim=True) + return (x_f32 * torch.rsqrt(var + eps)) * weight + + +class RmsNormModule(torch.nn.Module): + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(x, self.weight, self.eps) + + +class RmsChainModule(torch.nn.Module): + """rms(rms(x)) — two ops; exercises the resize-cascade (DD-4).""" + + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.w1 = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.w2 = torch.nn.Parameter( + torch.linspace(1.5, 0.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(_rms(x, self.w1, self.eps), self.w2, self.eps) + + +class RmsResidualModule(torch.nn.Module): + """rms(x) + x — rms op feeding an add op; proves the cross-op resize cascade.""" + + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.w = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(x, self.w, self.eps) + x + + +class RmsMulModule(torch.nn.Module): + """rms(x) * x — exercises the mul op (two same-shape dynamic operands).""" + + def __init__(self, hidden: int, eps: float = 1e-6) -> None: + super().__init__() + self.w = torch.nn.Parameter( + torch.linspace(0.5, 1.5, hidden, dtype=torch.float32) + ) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _rms(x, self.w, self.eps) * x + + +class SigmoidModule(torch.nn.Module): + """sigmoid(x) — elementwise; resize hook recomputes dispatch from live numel.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sigmoid(x) + + +class SelectModule(torch.nn.Module): + """x.select(0, -1) — negative index resolved live + dynamic output dispatch.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.select(0, -1) + + +def _ramp(shape) -> torch.Tensor: + n = 1 + for d in shape: + n *= d + return torch.linspace(-1.0, 1.0, n, dtype=torch.float32).reshape(shape) + + +def _export(model, example_inputs, dynamic_shapes, path: str) -> None: + ep = torch.export.export(model, example_inputs, dynamic_shapes=dynamic_shapes) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + found = any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ) + assert found, f"Expected VulkanBackend delegate in {path}" + with open(path, "wb") as f: + f.write(et.buffer) + print(f"Exported {path}") + + +def _write_goldens(model, prefix: str, out_dir: str, s_values) -> None: + for s in s_values: + x = _ramp((1, 1, s, HIDDEN)) + with torch.no_grad(): + g = model(x) + x.detach().numpy().astype(" None: + """Write the dynamic + static .pte's and per-S goldens for the native test.""" + os.makedirs(out_dir, exist_ok=True) + s_dim = torch.export.Dim("s", min=1, max=MAXS) + + # 1) Single dynamic rms_norm, graph built at S=MAXS (upper bound). + rms = RmsNormModule(HIDDEN) + _export( + rms, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_rms.pte"), + ) + _write_goldens(rms, "dyn_rms", out_dir, [MAXS, 64, 8, 1]) + + # 2) Two-op chain (cascade): rms(rms(x)). + chain = RmsChainModule(HIDDEN) + _export( + chain, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_rms_chain.pte"), + ) + _write_goldens(chain, "dyn_rms_chain", out_dir, [MAXS, 16, 1]) + + # 2b) rms(x)+x residual — cross-op (rms->add) cascade. + resid = RmsResidualModule(HIDDEN) + _export( + resid, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_residual.pte"), + ) + _write_goldens(resid, "dyn_residual", out_dir, [MAXS, 32, 1]) + + # 2c) rms(x)*x — exercises the mul op resize. + rmsmul = RmsMulModule(HIDDEN) + _export( + rmsmul, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_rmsmul.pte"), + ) + _write_goldens(rmsmul, "dyn_rmsmul", out_dir, [MAXS, 32, 1]) + + # 2d) 4-bit quantized linear with a DYNAMIC rows (M) dim — prefill GEMM. + _export_dynamic_linear(out_dir) + + # 2e) Fused SDPA with a DYNAMIC seq-len S (prefill, input_pos=0). + _export_dynamic_sdpa(out_dir) + + # 2f) 4-bit embedding with a DYNAMIC token count (int64 indices). + _export_dynamic_embedding(out_dir) + + # 2g) Interleaved RoPE with a DYNAMIC seq-len S (two outputs xq/xk). + _export_dynamic_rope(out_dir) + + # 2h) Elementwise sigmoid with a DYNAMIC seq-len S. + sig = SigmoidModule() + _export( + sig, + (_ramp((1, 1, MAXS, HIDDEN)),), + {"x": {2: s_dim}}, + os.path.join(out_dir, "dyn_sigmoid.pte"), + ) + _write_goldens(sig, "dyn_sigmoid", out_dir, [MAXS, 32, 1]) + + # 2i) select_copy(0, -1) over a DYNAMIC seq-len S (negative live index). + _export_dynamic_select(out_dir) + + # 3) Static rms_norm (no dynamic dim) — regression: must stay byte-identical. + static = RmsNormModule(HIDDEN) + _export( + static, + (_ramp((1, 1, 8, HIDDEN)),), + None, + os.path.join(out_dir, "static_rms.pte"), + ) + _write_goldens(static, "static_rms", out_dir, [8]) + + +# Quantized linear: K x N weight, dynamic rows M; input [M, K], output [M, N]. +LIN_K = 64 +LIN_N = 128 +LIN_GROUP = 32 +LIN_MAXM = 128 + + +def _export_dynamic_linear(out_dir: str) -> None: + from executorch.backends.webgpu.test.ops.quantized_linear.test_quantized_linear import ( + _fp64_golden, + _make_quantized_model, + ) + + model = _make_quantized_model(LIN_K, LIN_N, LIN_GROUP) + x = _ramp((LIN_MAXM, LIN_K)) + m_dim = torch.export.Dim("m", min=1, max=LIN_MAXM) + ep = torch.export.export(model, (x,), dynamic_shapes=({0: m_dim},)) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "linear_q4gsw not delegated" + with open(os.path.join(out_dir, "dyn_linear.pte"), "wb") as f: + f.write(et.buffer) + print("Exported dyn_linear.pte") + for m in [LIN_MAXM, 32, 1]: + xm = _ramp((m, LIN_K)) + g = _fp64_golden(model, xm).astype(" None: + from executorch.backends.webgpu.test.ops.sdpa.test_sdpa import ( + _det_inputs, + _golden, + SdpaConfig, + SdpaModule, + ) + + def cfg(s: int) -> "SdpaConfig": + return SdpaConfig("dyn", SD_HQ, SD_HKV, SD_D, s, SD_CMAX, 0) + + model = SdpaModule(0) + q, k, v, kc, vc = _det_inputs(cfg(SD_MAXS)) + s_dim = torch.export.Dim("s", min=1, max=SD_MAXS) + ds = ({1: s_dim}, {1: s_dim}, {1: s_dim}, None, None) + ep = torch.export.export(model, (q, k, v, kc, vc), dynamic_shapes=ds) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "sdpa not delegated" + with open(os.path.join(out_dir, "sdpa_dyn.pte"), "wb") as f: + f.write(et.buffer) + print("Exported sdpa_dyn.pte") + for s in [SD_MAXS, 16, 1]: + c = cfg(s) + q, k, v, kc, vc = _det_inputs(c) + g = _golden(c, q, k, v, kc, vc) + for name, t in [ + ("q", q), + ("k", k), + ("v", v), + ("kc", kc), + ("vc", vc), + ("golden", g), + ]: + t.detach().cpu().numpy().astype(" [N, EMBED] fp32. +EMB_VOCAB = 64 +EMB_DIM = 64 +EMB_GROUP = 32 +EMB_MAXN = 16 + + +def _export_dynamic_embedding(out_dir: str) -> None: + from executorch.backends.webgpu.test.ops.embedding_q4gsw.test_embedding_q4gsw import ( + _make_quantized_model, + _quant_params, + Shape, + ) + + shape = Shape("dyn", EMB_VOCAB, EMB_DIM, EMB_GROUP, list(range(EMB_MAXN))) + qm = _make_quantized_model(shape) + idx_max = torch.arange(EMB_MAXN, dtype=torch.long) + n_dim = torch.export.Dim("n", min=1, max=EMB_MAXN) + ep = torch.export.export(qm, (idx_max,), dynamic_shapes=({0: n_dim},)) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "embedding_q4gsw not delegated" + with open(os.path.join(out_dir, "emb_dyn.pte"), "wb") as f: + f.write(et.buffer) + print("Exported emb_dyn.pte") + weight, scales, group_size = _quant_params(qm) + for n in [EMB_MAXN, 8, 1]: + idx = (torch.arange(n, dtype=torch.long) * 7) % EMB_VOCAB + g = torch.ops.et_vk.embedding_q4gsw.default( + weight, scales, group_size, idx, False + ) + idx.detach().numpy().astype(" None: + from executorch.backends.webgpu.test.ops.rope.test_rope import ( + _golden, + _inputs, + Shape, + ) + from executorch.examples.models.llama.rope import RotaryEmbedding + + xq, xk, fc, fs = _inputs(Shape("dyn", 1, ROPE_MAXS, ROPE_NH, ROPE_NKV, ROPE_HD)) + s_dim = torch.export.Dim("s", min=1, max=ROPE_MAXS) + ds = ({1: s_dim}, {1: s_dim}, {0: s_dim}, {0: s_dim}) + ep = torch.export.export( + RotaryEmbedding().eval(), (xq, xk, fc, fs), dynamic_shapes=ds + ) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "apply_rotary_emb not delegated" + with open(os.path.join(out_dir, "rope_dyn.pte"), "wb") as f: + f.write(et.buffer) + print("Exported rope_dyn.pte") + for s in [ROPE_MAXS, 8, 1]: + xq, xk, fc, fs = _inputs(Shape("dyn", 1, s, ROPE_NH, ROPE_NKV, ROPE_HD)) + gq, gk = _golden(xq, xk, fc, fs) + for name, t in [ + ("xq", xq), + ("xk", xk), + ("fc", fc), + ("fs", fs), + ("gq", gq), + ("gk", gk), + ]: + t.detach().cpu().numpy().astype(" [1, S, HIDDEN]. +SEL_LEAD = 2 + + +def _export_dynamic_select(out_dir: str) -> None: + model = SelectModule() + s_dim = torch.export.Dim("s", min=1, max=MAXS) + ep = torch.export.export( + model, + (_ramp((SEL_LEAD, 1, MAXS, HIDDEN)),), + dynamic_shapes=({2: s_dim},), + ) + et = to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + assert any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ), "select_copy not delegated" + with open(os.path.join(out_dir, "dyn_select.pte"), "wb") as f: + f.write(et.buffer) + print("Exported dyn_select.pte") + for s in [MAXS, 32, 1]: + x = _ramp((SEL_LEAD, 1, s, HIDDEN)) + with torch.no_grad(): + g = model(x) + x.detach().numpy().astype(" None: + import tempfile + + with tempfile.TemporaryDirectory() as d: + export_dynamic_shape_cases(d) + self.assertTrue(os.path.exists(os.path.join(d, "dyn_rms.pte"))) + self.assertTrue(os.path.exists(os.path.join(d, "dyn_rms.S1.golden.bin"))) + + +if __name__ == "__main__": + unittest.main()