From 438512c8707f64c8b5a6486490140ce5caa77e27 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:01 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- .../ops/select_as_symint/SelectAsSymint.cpp | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp index 573a88ce0fe..aea2a19c058 100644 --- a/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp +++ b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp @@ -10,6 +10,7 @@ #include #include +#include #include namespace executorch::backends::webgpu { @@ -38,10 +39,70 @@ void select_as_symint_impl(WebGPUGraph& graph, const std::vector& args) { static_cast(graph.get_int(index_id))); } +// An operand is a live SymInt or a static Int constant. +int32_t read_scalar(WebGPUGraph& graph, int id) { + if (graph.get_value_type(id) == WebGPUGraph::ValueType::SymInt) { + return graph.read_symint(id); + } + return static_cast(graph.get_int(id)); +} + +// SymInt arithmetic; mirrors Vulkan SymIntOps.cpp, recomputed on resize. +void register_sym_binary( + WebGPUGraph& graph, + const std::vector& args, + std::function op) { + if (args.size() < 3) { + throw std::runtime_error("sym binary op: expected [a, b, out] args"); + } + const int a = args.at(0); + const int b = args.at(1); + const int out = args.at(2); + if (graph.get_value_type(out) != WebGPUGraph::ValueType::SymInt) { + return; // folded to a static Int -> nothing live to compute + } + auto recompute = [a, b, out, op](WebGPUGraph& g) { + g.set_symint(out, op(read_scalar(g, a), read_scalar(g, b))); + }; + recompute(graph); // seed the build-time value + if (graph.get_value_type(a) == WebGPUGraph::ValueType::SymInt) { + graph.add_resize_hook(a, recompute); + } + if (graph.get_value_type(b) == WebGPUGraph::ValueType::SymInt) { + graph.add_resize_hook(b, recompute); + } +} + +void sym_add_impl(WebGPUGraph& graph, const std::vector& args) { + register_sym_binary(graph, args, [](int32_t x, int32_t y) { return x + y; }); +} + +void sym_sub_impl(WebGPUGraph& graph, const std::vector& args) { + register_sym_binary(graph, args, [](int32_t x, int32_t y) { return x - y; }); +} + +void sym_mul_impl(WebGPUGraph& graph, const std::vector& args) { + register_sym_binary(graph, args, [](int32_t x, int32_t y) { return x * y; }); +} + +void sym_floordiv_impl(WebGPUGraph& graph, const std::vector& args) { + register_sym_binary(graph, args, [](int32_t x, int32_t y) { + int32_t q = x / y; + if ((x % y != 0) && ((x < 0) != (y < 0))) { + q--; // round toward negative infinity (Python floor division) + } + return q; + }); +} + } // namespace WEBGPU_REGISTER_OPERATORS { WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl); + WEBGPU_REGISTER_OP(add, sym_add_impl); + WEBGPU_REGISTER_OP(sub, sym_sub_impl); + WEBGPU_REGISTER_OP(mul, sym_mul_impl); + WEBGPU_REGISTER_OP(floordiv, sym_floordiv_impl); } } // namespace executorch::backends::webgpu