@@ -34,11 +34,28 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
 REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17);
 REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18);
 
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17);
+REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 18);
+
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10);
+REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 11, 12);
+REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 13);
+
 Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  if (is_input_empty_) {
+    shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                              << code_[0]
+                              << code_[2]
+                              << output.SetByOffset("global_idx", "output_value");
+    return Status::OK();
+  }
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
   bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty();
-  std::string loop_header = code_[0];
+  std::string loop_header = code_[0].find("first_element") == std::string::npos ? code_[0] : "let first_element = " + input.GetByIndices("input_indices") + ";\n" + code_[0] + "\n";
   std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1];
   std::string loop_footer = code_[2];
   const auto input_rank = input.Rank();
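Note on the `loop_header` change above: when an op's header references `first_element` (ReduceMax below does), the generated WGSL must read the first input element before the header runs; otherwise the header is used unchanged. A minimal standalone sketch of that injection rule, with a hypothetical helper name `BuildLoopHeader` and a plain string standing in for whatever `input.GetByIndices(...)` returns:

#include <string>

// Sketch only, not the ORT API: mirrors the ternary on the added line.
// Ops such as ReduceMax seed their accumulator from the data, so their
// header needs "first_element" defined first; sum-style ops do not.
std::string BuildLoopHeader(const std::string& op_header, const std::string& get_by_indices_expr) {
  if (op_header.find("first_element") == std::string::npos) {
    return op_header;  // e.g. ReduceSum: "var sum = f32(0);"
  }
  return "let first_element = " + get_by_indices_expr + ";\n" + op_header + "\n";
}
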
@@ -56,10 +73,10 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
       loop_body = ss.str();
     } else {
       std::stringstream ss;
-      ss << loop_header << "\n";
       std::string index = "i" + std::to_string(i);
       ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n";
       ss << input.IndicesSet("input_indices", i, index) << ";\n";
+      ss << loop_header << "\n";
       loop_header = ss.str();
       l++;
     }
@@ -80,6 +97,7 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
 template <bool allow_multi_axes>
 Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context) const {
   const auto* input_tensor = context.Input(0);
+  ORT_RETURN_IF_ERROR(CheckInput(input_tensor));
   InlinedVector<uint32_t> input_axes;
   auto rank = input_tensor->Shape().NumDimensions();
   auto transform_axis = [rank](int64_t axis) {
@@ -95,10 +113,12 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
   if (context.InputCount() > 1) {
     ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided.");
     const Tensor* axes_tensor = context.Input<Tensor>(1);
-    auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
-    const auto* data = axes_tensor->Data<int64_t>();
-    input_axes.reserve(size);
-    std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
+    if (nullptr != axes_tensor) {
+      auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
+      const auto* data = axes_tensor->Data<int64_t>();
+      input_axes.reserve(size);
+      std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
+    }
   } else {
     input_axes.reserve(axes_.size());
     std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis);
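The body of the `transform_axis` lambda is outside this hunk; the sketch below shows the normalization it is assumed to perform (ONNX permits negative axes, which wrap around the tensor rank). `NormalizeAxis` is a hypothetical standalone name, not the kernel's API, and the real lambda may also validate the range.

#include <cstddef>
#include <cstdint>

// Assumed behavior: map a possibly negative ONNX axis to an unsigned index.
uint32_t NormalizeAxis(int64_t axis, size_t rank) {
  if (axis < 0) {
    axis += static_cast<int64_t>(rank);  // e.g. axis -1 on a rank-3 tensor becomes 2
  }
  return static_cast<uint32_t>(axis);
}
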
@@ -120,10 +140,12 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
       std::iota(input_axes.begin(), input_axes.end(), 0);
     }
   }
-  const auto code = GetOpSpecificCode(input_tensor, input_axes.size());
+  const auto code = GetOpSpecificCode(input_tensor);
   // Compute output shape
   std::vector<int64_t> output_shape;
+  bool is_input_empty = false;
   for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) {
+    is_input_empty |= input_tensor->Shape()[i] == 0;
     if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) {
       if (keepdims_) {
         output_shape.push_back(1);
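For reference, the output-shape rule this loop implements: a reduced axis collapses to 1 when `keepdims_` is set and is dropped otherwise, while non-reduced axes pass through. A self-contained sketch with a hypothetical `ReducedShape` helper (the kernel builds `output_shape` in place instead):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the shape rule only.
std::vector<int64_t> ReducedShape(const std::vector<int64_t>& input_shape,
                                  const std::vector<uint32_t>& axes,
                                  bool keepdims) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < input_shape.size(); ++i) {
    const bool reduced =
        std::find(axes.begin(), axes.end(), static_cast<uint32_t>(i)) != axes.end();
    if (!reduced) {
      out.push_back(input_shape[i]);
    } else if (keepdims) {
      out.push_back(1);
    }
  }
  return out;
}
// Example: {2, 3, 4} reduced over axis 1 -> {2, 4}, or {2, 1, 4} with keepdims.
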
@@ -134,34 +156,68 @@ Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context)
   }
   TensorShape output_tensor_shape(output_shape);
   int64_t output_size = output_tensor_shape.Size();
-  ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code);
-  program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank})
+  if (output_size == 0) {
+    ORT_IGNORE_RETURN_VALUE(context.Output(0, output_tensor_shape));
+    return Status::OK();
+  }
+
+  auto input_rank = input_tensor->Shape().NumDimensions();
+  // reduce_axes element is either 1 or 0 depending on whether the axis is reduced or not
+  std::vector<uint32_t> reduce_axes;
+  reduce_axes.resize(input_rank, 0);
+  for (auto axis : input_axes) {
+    reduce_axes[axis] = 1;
+  }
+
+  ReduceKernelProgram program(name_, keepdims_, noop_with_empty_axes_, input_axes, code, is_input_empty);
+  if (!is_input_empty) {
+    program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank});
+  }
+
+  program.CacheHint(is_input_empty)
       .AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank})
       .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
       .AddUniformVariables({{static_cast<uint32_t>(output_size)},
                             {static_cast<uint32_t>(noop_with_empty_axes_ ? 1 : 0)},
-                            {input_axes},
-                            {static_cast<uint32_t>(input_axes.size())}});
+                            {reduce_axes}});
 
   return context.RunProgram(program);
 }
 
-ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const {
+ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor) const {
   const TensorShape& input_shape = input_tensor->Shape();
   size_t input_rank = input_shape.NumDimensions();
+  std::string loop_header = "var sum = f32(0);";
+  std::string loop_body = "sum += f32(current_element);";
   std::stringstream ss;
   ss << "var size: u32 = 1;\n"
-     << "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) { \n"
-     << "  let index = " << GetElementAt("uniforms.axes", "i", axes_size) << ";\n"
-     << "  size = size * " << GetElementAt("uniforms.input_shape", "index", input_rank) << ";\n"
+     << "for (var i: u32 = 0; i < " << input_rank << "; i += 1) { \n"
+     << "  let index_reduced_or_not = " << GetElementAt("uniforms.reduce_axes", "i", input_rank) << ";\n"
+     << "  if (index_reduced_or_not == 1) { \n"
+     << "    size = size * " << GetElementAt("uniforms.input_shape", "i", input_rank) << ";\n"
+     << "  }\n"
      << "}\n"
      << "let output_value = output_value_t(sum / f32(size));";
-  ReduceOpSpecificCode code({"var sum = f32(0);", "sum += f32(current_element);", ss.str()});
+  std::string loop_footer = ss.str();
+  ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
   return code;
 }
 
-Status ReduceMean::ComputeInternal(ComputeContext& ctx) const {
-  return ReduceKernel<true>::ComputeInternal(ctx);
+ReduceOpSpecificCode ReduceMax::GetOpSpecificCode(const Tensor* input_tensor) const {
+  ORT_UNUSED_PARAMETER(input_tensor);
+  std::string loop_header = "var max_element = first_element;";
+  std::string loop_body = "max_element = max(max_element, current_element);";
+  std::string loop_footer = "let output_value = output_value_t(max_element);";
+  ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
+  return code;
+}
+ReduceOpSpecificCode ReduceSum::GetOpSpecificCode(const Tensor* input_tensor) const {
+  ORT_UNUSED_PARAMETER(input_tensor);
+  std::string loop_header = "var sum = f32(0);";
+  std::string loop_body = "sum += f32(current_element);";
+  std::string loop_footer = "let output_value = output_value_t(sum);";
+  ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
+  return code;
 }
 
 }  // namespace webgpu
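
The `reduce_axes` uniform added in this change is a per-axis 0/1 mask. The ReduceMean footer walks it in WGSL to compute the divisor; the same computation expressed as a standalone C++ sketch (hypothetical `ReducedElementCount` helper, not part of the kernel):

#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the WGSL emitted by ReduceMean::GetOpSpecificCode: size is the
// product of the input dimensions whose mask entry is 1.
uint32_t ReducedElementCount(const std::vector<uint32_t>& input_shape,
                             const std::vector<uint32_t>& reduce_axes_mask) {
  uint32_t size = 1;
  for (size_t i = 0; i < input_shape.size(); ++i) {
    if (reduce_axes_mask[i] == 1) {  // axis i is reduced
      size *= input_shape[i];
    }
  }
  return size;  // ReduceMean divides the accumulated sum by f32(size)
}

Passing the mask instead of the old axes list plus `axes_size` lets the generated loop bound be the input rank, which is known at shader-generation time, so the shader no longer needs a separate uniform for the axis count.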