Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <stdexcept>

namespace executorch::backends::webgpu {
Expand Down Expand Up @@ -38,10 +40,77 @@ void select_as_symint_impl(WebGPUGraph& graph, const std::vector<int>& args) {
static_cast<int>(graph.get_int(index_id)));
}

// An operand is a live SymInt or a static Int constant.
int32_t read_scalar(WebGPUGraph& graph, int id) {
if (graph.get_value_type(id) == WebGPUGraph::ValueType::SymInt) {
return graph.read_symint(id);
}
return static_cast<int32_t>(graph.get_int(id));
}

// SymInt arithmetic; mirrors Vulkan SymIntOps.cpp, recomputed on resize.
void register_sym_binary(
WebGPUGraph& graph,
const std::vector<int>& args,
std::function<int32_t(int32_t, int32_t)> op) {
if (args.size() < 3) {
throw std::runtime_error("sym binary op: expected [a, b, out] args");
}
const int a = args.at(0);
const int b = args.at(1);
const int out = args.at(2);
if (graph.get_value_type(out) != WebGPUGraph::ValueType::SymInt) {
return; // folded to a static Int -> nothing live to compute
}
auto recompute = [a, b, out, op](WebGPUGraph& g) {
g.set_symint(out, op(read_scalar(g, a), read_scalar(g, b)));
};
recompute(graph); // seed the build-time value
if (graph.get_value_type(a) == WebGPUGraph::ValueType::SymInt) {
graph.add_resize_hook(a, recompute);
}
if (graph.get_value_type(b) == WebGPUGraph::ValueType::SymInt) {
graph.add_resize_hook(b, recompute);
}
}

void sym_add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
register_sym_binary(graph, args, [](int32_t x, int32_t y) { return x + y; });
}

void sym_sub_impl(WebGPUGraph& graph, const std::vector<int>& args) {
register_sym_binary(graph, args, [](int32_t x, int32_t y) { return x - y; });
}

void sym_mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
register_sym_binary(graph, args, [](int32_t x, int32_t y) { return x * y; });
}

void sym_floordiv_impl(WebGPUGraph& graph, const std::vector<int>& args) {
register_sym_binary(graph, args, [](int32_t x, int32_t y) {
if (y == 0) {
throw std::runtime_error("sym floordiv: division by zero");
}
if (x == INT32_MIN && y == -1) {
throw std::runtime_error(
"sym floordiv: signed overflow (INT32_MIN / -1)");
}
int32_t q = x / y;
if ((x % y != 0) && ((x < 0) != (y < 0))) {
q--; // round toward negative infinity (Python floor division)
}
return q;
});
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl);
WEBGPU_REGISTER_OP(add, sym_add_impl);
WEBGPU_REGISTER_OP(sub, sym_sub_impl);
WEBGPU_REGISTER_OP(mul, sym_mul_impl);
WEBGPU_REGISTER_OP(floordiv, sym_floordiv_impl);
}

} // namespace executorch::backends::webgpu
Loading