@@ -34,11 +34,28 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
 REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17);
 REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18);
 
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17);
+REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 18);
+
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 11, 12);
+REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 13);
+
 Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
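+  // Fast path for an empty input: only the loop header (code_[0]) and footer (code_[2])
+  // are emitted, so the shader writes an output value without ever reading the empty input.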
+  if (is_input_empty_) {
+    shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                              << code_[0]
+                              << code_[2]
+                              << output.SetByOffset("global_idx", "output_value");
+    return Status::OK();
+  }
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
   bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty();
-  std::string loop_header = code_[0];
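+  // Ops such as ReduceMax seed their accumulator from the data via "first_element";
+  // when the header references it, prepend a declaration reading the current input element.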
+  std::string loop_header = code_[0].find("first_element") == std::string::npos ? code_[0] : "let first_element = " + input.GetByIndices("input_indices") + ";\n" + code_[0] + "\n";
   std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1];
   std::string loop_footer = code_[2];
   const auto input_rank = input.Rank();
@@ -56,10 +73,10 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
       loop_body = ss.str();
     } else {
       std::stringstream ss;
-      ss << loop_header << "\n";
       std::string index = "i" + std::to_string(i);
       ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n";
       ss << input.IndicesSet("input_indices", i, index) << ";\n";
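+      // loop_header may reference first_element, which reads input_indices, so emit it after the indices are set.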
+      ss << loop_header << "\n";
       loop_header = ss.str();
       l++;
     }
@@ -80,6 +97,7 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
 template <bool allow_multi_axes>
 Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context) const {
   const auto* input_tensor = context.Input(0);
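+  // Validate the input before building the shader program.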
+  ORT_RETURN_IF_ERROR(CheckInput(input_tensor));
   InlinedVector<uint32_t> input_axes;
   auto rank = input_tensor->Shape().NumDimensions();
   auto transform_axis = [rank](int64_t axis) {
@@ -95,10 +113,12 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
   if (context.InputCount() > 1) {
     ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided.");
     const Tensor* axes_tensor = context.Input<Tensor>(1);
-    auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
-    const auto* data = axes_tensor->Data<int64_t>();
-    input_axes.reserve(size);
-    std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
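+    // The axes input is optional, so it may be null; only read it when present.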
+    if (nullptr != axes_tensor) {
+      auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
+      const auto* data = axes_tensor->Data<int64_t>();
+      input_axes.reserve(size);
+      std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
+    }
   } else {
     input_axes.reserve(axes_.size());
     std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis);
@@ -120,10 +140,12 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
       std::iota(input_axes.begin(), input_axes.end(), 0);
     }
   }
-  const auto code = GetOpSpecificCode(input_tensor, input_axes.size());
+  const auto code = GetOpSpecificCode(input_tensor);
   // Compute output shape
   std::vector<int64_t> output_shape;
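+  // An input with any zero-sized dimension is empty and gets a dedicated shader path.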
+  bool is_input_empty = false;
   for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) {
+    is_input_empty |= input_tensor->Shape()[i] == 0;
     if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) {
       if (keepdims_) {
         output_shape.push_back(1);
@@ -134,34 +156,68 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
   }
   TensorShape output_tensor_shape(output_shape);
   int64_t output_size = output_tensor_shape.Size();
-  ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code);
-  program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank})
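+  // A zero-sized output has nothing to compute; allocate it and return early.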
+  if (output_size == 0) {
+    ORT_IGNORE_RETURN_VALUE(context.Output(0, output_tensor_shape));
+    return Status::OK();
+  }
+
+  auto input_rank = input_tensor->Shape().NumDimensions();
+  // Each element of reduce_axes is 1 if the corresponding axis is reduced and 0 otherwise.
+  std::vector<uint32_t> reduce_axes;
+  reduce_axes.resize(input_rank, 0);
+  for (auto axis : input_axes) {
+    reduce_axes[axis] = 1;
+  }
+
+  ReduceKernelProgram program(name_, keepdims_, noop_with_empty_axes_, input_axes, code, is_input_empty);
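+  // The empty-input shader never reads the input, so skip binding it.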
+  if (!is_input_empty) {
+    program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank});
+  }
+
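+  // The generated shader differs between the empty and non-empty paths, so the flag
+  // must be part of the program cache key.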
+  program.CacheHint(is_input_empty)
       .AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank})
       .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
       .AddUniformVariables({{static_cast<uint32_t>(output_size)},
                             {static_cast<uint32_t>(noop_with_empty_axes_ ? 1 : 0)},
-                            {input_axes},
-                            {static_cast<uint32_t>(input_axes.size())}});
+                            {reduce_axes}});
 
   return context.RunProgram(program);
 }
 
-ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const {
+ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor) const {
   const TensorShape& input_shape = input_tensor->Shape();
   size_t input_rank = input_shape.NumDimensions();
+  std::string loop_header = "var sum = f32(0);";
+  std::string loop_body = "sum += f32(current_element);";
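+  // ReduceMean's footer divides the f32 accumulator by the number of reduced elements,
+  // computed at runtime as the product of the sizes of the reduced axes (reduce_axes uniform).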
   std::stringstream ss;
   ss << "var size: u32 = 1;\n"
-     << "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) {\n"
-     << "  let index = " << GetElementAt("uniforms.axes", "i", axes_size) << ";\n"
-     << "  size = size * " << GetElementAt("uniforms.input_shape", "index", input_rank) << ";\n"
+     << "for (var i: u32 = 0; i < " << input_rank << "; i += 1) {\n"
+     << "  let index_reduced_or_not = " << GetElementAt("uniforms.reduce_axes", "i", input_rank) << ";\n"
+     << "  if (index_reduced_or_not == 1) {\n"
+     << "    size = size * " << GetElementAt("uniforms.input_shape", "i", input_rank) << ";\n"
+     << "  }\n"
      << "}\n"
      << "let output_value = output_value_t(sum / f32(size));";
-  ReduceOpSpecificCode code({"var sum = f32(0);", "sum += f32(current_element);", ss.str()});
+  std::string loop_footer = ss.str();
+  ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
   return code;
 }
 
-Status ReduceMean::ComputeInternal(ComputeContext& ctx) const {
-  return ReduceKernel<true>::ComputeInternal(ctx);
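+// Sketch of how the three ReduceMax snippets land in the generated WGSL (not verbatim):
+//   var max_element = first_element;                   // loop header, seeded from the data
+//   max_element = max(max_element, current_element);   // loop body, once per reduced element
+//   let output_value = output_value_t(max_element);    // loop footer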
+ReduceOpSpecificCode ReduceMax::GetOpSpecificCode(const Tensor* input_tensor) const {
+  ORT_UNUSED_PARAMETER(input_tensor);
+  std::string loop_header = "var max_element = first_element;";
+  std::string loop_body = "max_element = max(max_element, current_element);";
+  std::string loop_footer = "let output_value = output_value_t(max_element);";
+  ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
+  return code;
+}
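+// ReduceSum accumulates in f32 even for narrower element types (e.g. f16), casting back
+// to the output value type only in the footer.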
+ReduceOpSpecificCode ReduceSum::GetOpSpecificCode(const Tensor* input_tensor) const {
+  ORT_UNUSED_PARAMETER(input_tensor);
+  std::string loop_header = "var sum = f32(0);";
+  std::string loop_body = "sum += f32(current_element);";
+  std::string loop_footer = "let output_value = output_value_t(sum);";
+  ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
+  return code;
 }
 
 }  // namespace webgpu