Skip to content

Commit 6dd6ef9

Browse files
[Native WebGPU] Added ReduceMax and ReduceSum (#23934)
### Description Added ReduceMax and ReduceSum. ### Motivation and Context Extends the operator coverage of the native WebGPU execution provider by implementing the ReduceMax and ReduceSum reduction kernels.
1 parent 9891eb3 commit 6dd6ef9

File tree

3 files changed

+112
-37
lines changed

3 files changed

+112
-37
lines changed

onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,28 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
3434
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17);
3535
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18);
3636

37+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10);
38+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11);
39+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12);
40+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17);
41+
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 18);
42+
43+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10);
44+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 11, 12);
45+
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 13);
46+
3747
Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
38-
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
3948
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
49+
if (is_input_empty_) {
50+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
51+
<< code_[0]
52+
<< code_[2]
53+
<< output.SetByOffset("global_idx", "output_value");
54+
return Status::OK();
55+
}
56+
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
4057
bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty();
41-
std::string loop_header = code_[0];
58+
std::string loop_header = code_[0].find("first_element") == std::string::npos ? code_[0] : "let first_element = " + input.GetByIndices("input_indices") + ";\n" + code_[0] + "\n";
4259
std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1];
4360
std::string loop_footer = code_[2];
4461
const auto input_rank = input.Rank();
@@ -56,10 +73,10 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
5673
loop_body = ss.str();
5774
} else {
5875
std::stringstream ss;
59-
ss << loop_header << "\n";
6076
std::string index = "i" + std::to_string(i);
6177
ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n";
6278
ss << input.IndicesSet("input_indices", i, index) << ";\n";
79+
ss << loop_header << "\n";
6380
loop_header = ss.str();
6481
l++;
6582
}
@@ -80,6 +97,7 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
8097
template <bool allow_multi_axes>
8198
Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context) const {
8299
const auto* input_tensor = context.Input(0);
100+
ORT_RETURN_IF_ERROR(CheckInput(input_tensor));
83101
InlinedVector<uint32_t> input_axes;
84102
auto rank = input_tensor->Shape().NumDimensions();
85103
auto transform_axis = [rank](int64_t axis) {
@@ -95,10 +113,12 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
95113
if (context.InputCount() > 1) {
96114
ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided.");
97115
const Tensor* axes_tensor = context.Input<Tensor>(1);
98-
auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
99-
const auto* data = axes_tensor->Data<int64_t>();
100-
input_axes.reserve(size);
101-
std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
116+
if (nullptr != axes_tensor) {
117+
auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
118+
const auto* data = axes_tensor->Data<int64_t>();
119+
input_axes.reserve(size);
120+
std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
121+
}
102122
} else {
103123
input_axes.reserve(axes_.size());
104124
std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis);
@@ -120,10 +140,12 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
120140
std::iota(input_axes.begin(), input_axes.end(), 0);
121141
}
122142
}
123-
const auto code = GetOpSpecificCode(input_tensor, input_axes.size());
143+
const auto code = GetOpSpecificCode(input_tensor);
124144
// Compute output shape
125145
std::vector<int64_t> output_shape;
146+
bool is_input_empty = false;
126147
for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) {
148+
is_input_empty |= input_tensor->Shape()[i] == 0;
127149
if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) {
128150
if (keepdims_) {
129151
output_shape.push_back(1);
@@ -134,34 +156,68 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
134156
}
135157
TensorShape output_tensor_shape(output_shape);
136158
int64_t output_size = output_tensor_shape.Size();
137-
ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code);
138-
program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank})
159+
if (output_size == 0) {
160+
ORT_IGNORE_RETURN_VALUE(context.Output(0, output_tensor_shape));
161+
return Status::OK();
162+
}
163+
164+
auto input_rank = input_tensor->Shape().NumDimensions();
165+
// reduce_axes element is either 1 or 0 depending on whether the axis is reduced or not
166+
std::vector<uint32_t> reduce_axes;
167+
reduce_axes.resize(input_rank, 0);
168+
for (auto axis : input_axes) {
169+
reduce_axes[axis] = 1;
170+
}
171+
172+
ReduceKernelProgram program(name_, keepdims_, noop_with_empty_axes_, input_axes, code, is_input_empty);
173+
if (!is_input_empty) {
174+
program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank});
175+
}
176+
177+
program.CacheHint(is_input_empty)
139178
.AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank})
140179
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
141180
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
142181
{static_cast<uint32_t>(noop_with_empty_axes_ ? 1 : 0)},
143-
{input_axes},
144-
{static_cast<uint32_t>(input_axes.size())}});
182+
{reduce_axes}});
145183

146184
return context.RunProgram(program);
147185
}
148186

149-
ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const {
187+
ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor) const {
150188
const TensorShape& input_shape = input_tensor->Shape();
151189
size_t input_rank = input_shape.NumDimensions();
190+
std::string loop_header = "var sum = f32(0);";
191+
std::string loop_body = "sum += f32(current_element);";
152192
std::stringstream ss;
153193
ss << "var size: u32 = 1;\n"
154-
<< "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) { \n"
155-
<< " let index = " << GetElementAt("uniforms.axes", "i", axes_size) << ";\n"
156-
<< " size = size * " << GetElementAt("uniforms.input_shape", "index", input_rank) << ";\n"
194+
<< "for (var i: u32 = 0; i < " << input_rank << "; i += 1) { \n"
195+
<< " let index_reduced_or_not = " << GetElementAt("uniforms.reduce_axes", "i", input_rank) << ";\n"
196+
<< " if (index_reduced_or_not == 1) { \n"
197+
<< " size = size * " << GetElementAt("uniforms.input_shape", "i", input_rank) << ";\n"
198+
<< " }\n"
157199
<< "}\n"
158200
<< "let output_value = output_value_t(sum / f32(size));";
159-
ReduceOpSpecificCode code({"var sum = f32(0);", "sum += f32(current_element);", ss.str()});
201+
std::string loop_footer = ss.str();
202+
ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
160203
return code;
161204
}
162205

163-
Status ReduceMean::ComputeInternal(ComputeContext& ctx) const {
164-
return ReduceKernel<true>::ComputeInternal(ctx);
206+
ReduceOpSpecificCode ReduceMax::GetOpSpecificCode(const Tensor* input_tensor) const {
207+
ORT_UNUSED_PARAMETER(input_tensor);
208+
std::string loop_header = "var max_element = first_element;";
209+
std::string loop_body = "max_element = max(max_element, current_element);";
210+
std::string loop_footer = "let output_value = output_value_t(max_element);";
211+
ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
212+
return code;
213+
}
214+
ReduceOpSpecificCode ReduceSum::GetOpSpecificCode(const Tensor* input_tensor) const {
215+
ORT_UNUSED_PARAMETER(input_tensor);
216+
std::string loop_header = "var sum = f32(0);";
217+
std::string loop_body = "sum += f32(current_element);";
218+
std::string loop_footer = "let output_value = output_value_t(sum);";
219+
ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
220+
return code;
165221
}
166222

167223
} // namespace webgpu

onnxruntime/core/providers/webgpu/reduction/reduction_ops.h

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,23 @@ namespace webgpu {
1313
// reduceOpSpecificCode is a 3-element array of strings that represent the op specific code for the reduce operation.
1414
// The first element is the loop header, the second element is the loop body, and the third element is the loop footer.
1515
// The loop header is the code that is executed before the loop starts. The loop body is the code that is executed for each element in the loop.
16-
// The loop footer is the code that is executed after the loop ends.
16+
// The loop footer is the code that is executed after the loop ends. The loop body should contain the code that accumulates the result of the reduction and
17+
// the loop footer should contain the code that assigns the result of the reduction to output_value.
1718
typedef std::array<std::string, 3> ReduceOpSpecificCode;
1819
class ReduceKernelProgram final : public Program<ReduceKernelProgram> {
1920
public:
20-
ReduceKernelProgram(std::string name, bool keepdims, bool no_op_with_empty_axes, const InlinedVector<uint32_t>& axes, ReduceOpSpecificCode code) : Program{name}, keepdims_(keepdims), no_op_with_empty_axes_(no_op_with_empty_axes), axes_(axes.begin(), axes.end()), code_(code) {}
21+
ReduceKernelProgram(std::string name, bool keepdims, bool no_op_with_empty_axes, const InlinedVector<uint32_t>& axes, ReduceOpSpecificCode code, bool is_input_empty) : Program{name}, keepdims_(keepdims), no_op_with_empty_axes_(no_op_with_empty_axes), axes_(axes.begin(), axes.end()), code_(code), is_input_empty_(is_input_empty) {}
2122
Status GenerateShaderCode(ShaderHelper& wgpuShaderModuleAddRef) const override;
2223
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
2324
{"no_op_with_empty_axes", ProgramUniformVariableDataType::Uint32},
24-
{"axes", ProgramUniformVariableDataType::Uint32},
25-
{"axes_size", ProgramUniformVariableDataType::Uint32});
25+
{"reduce_axes", ProgramUniformVariableDataType::Uint32});
2626

2727
private:
2828
const bool keepdims_;
2929
const bool no_op_with_empty_axes_;
3030
InlinedVector<uint32_t> axes_;
3131
ReduceOpSpecificCode code_;
32+
bool is_input_empty_;
3233
};
3334

3435
template <bool allow_multi_axes = true>
@@ -39,23 +40,41 @@ class ReduceKernel : public WebGpuKernel, public ReduceKernelBase<allow_multi_ax
3940
using ReduceKernelBase<allow_multi_axes>::keepdims_;
4041
using ReduceKernelBase<allow_multi_axes>::select_last_index_;
4142

42-
ReduceKernel(const OpKernelInfo& info, std::string name, optional<int64_t> keepdims_override = {})
43+
ReduceKernel(const OpKernelInfo& info, std::string name, bool allow_empty_input = false, optional<int64_t> keepdims_override = {})
4344
: WebGpuKernel(info),
4445
ReduceKernelBase<allow_multi_axes>(info, keepdims_override),
45-
name_(name) {
46+
name_(name),
47+
allow_empty_input_(allow_empty_input) {
4648
}
4749
Status ComputeInternal(ComputeContext& ctx) const;
48-
virtual ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const = 0;
50+
virtual ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const = 0;
51+
52+
Status CheckInput(const Tensor* input_tensor) const {
53+
ORT_ENFORCE(input_tensor != nullptr && (input_tensor->Shape().Size() > 0 || allow_empty_input_), "Input tensor cannot be null or empty");
54+
return Status::OK();
55+
}
4956

5057
private:
5158
std::string name_;
59+
bool allow_empty_input_;
5260
};
5361

5462
class ReduceMean final : public ReduceKernel<true> {
5563
public:
56-
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info, "ReduceMean") {}
57-
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const override;
58-
Status ComputeInternal(ComputeContext& ctx) const override;
64+
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info, "ReduceMean", true) {}
65+
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override;
66+
};
67+
68+
class ReduceMax final : public ReduceKernel<true> {
69+
public:
70+
ReduceMax(const OpKernelInfo& info) : ReduceKernel<true>(info, "ReduceMax") {}
71+
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override;
72+
};
73+
74+
class ReduceSum final : public ReduceKernel<true> {
75+
public:
76+
ReduceSum(const OpKernelInfo& info) : ReduceKernel<true>(info, "ReduceSum", true) {}
77+
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override;
5978
};
6079

6180
} // namespace webgpu

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -513,11 +513,11 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
513513
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
514514
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
515515
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze)>,
516-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax)>,
517-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax)>,
518-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax)>,
519-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax)>,
520-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax)>,
516+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax)>,
517+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax)>,
518+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax)>,
519+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax)>,
520+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax)>,
521521

522522
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean)>,
523523
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean)>,
@@ -539,9 +539,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
539539
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceProd)>,
540540
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceProd)>,
541541

542-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum)>,
543-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum)>,
544-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum)>,
542+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum)>,
543+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum)>,
544+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum)>,
545545

546546
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL1)>,
547547
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL1)>,

0 commit comments

Comments
 (0)