Skip to content

Commit 9891eb3

Browse files
authored
[WebGPU EP] Implements Gelu, BiasSplitGelu, and QuickGelu (#23981)
Increases WebGPU operator coverage
1 parent 16d6f39 commit 9891eb3

File tree

8 files changed

+250
-36
lines changed

8 files changed

+250
-36
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/shader_helper.h"
5+
#include "core/providers/webgpu/webgpu_supported_types.h"
6+
#include "contrib_ops/webgpu/bert/bias_split_gelu.h"
7+
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
8+
#include "core/providers/webgpu/webgpu_utils.h"
9+
#include "core/providers/webgpu/math/unary_elementwise_ops.h"
10+
11+
namespace onnxruntime {
12+
namespace contrib {
13+
namespace webgpu {
14+
15+
ONNX_OPERATOR_KERNEL_EX(
16+
BiasSplitGelu,
17+
kMSDomain,
18+
1,
19+
kWebGpuExecutionProvider,
20+
(*KernelDefBuilder::Create())
21+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
22+
BiasSplitGelu);
23+
24+
Status BiasSplitGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
25+
const ShaderVariableHelper& input = shader.AddInput("input");
26+
const ShaderVariableHelper& bias = shader.AddInput("bias");
27+
const ShaderVariableHelper& output = shader.AddOutput("output");
28+
29+
shader.AdditionalImplementation() << ErfImpl;
30+
31+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
32+
<< "const M_SQRT2: f32 = sqrt(2.0);\n"
33+
<< "const halfChannels = uniforms.channels / 2u;\n"
34+
<< "let biasIdx = global_idx % halfChannels;\n"
35+
<< "let batchIndex = global_idx / halfChannels;\n"
36+
<< "let inputOffset = biasIdx + batchIndex * halfChannels * 2;\n"
37+
<< "let valueLeft = " << input.GetByOffset("inputOffset") << " + " << bias.GetByOffset("biasIdx") << ";\n"
38+
<< "let valueRight = " << input.GetByOffset("inputOffset + halfChannels") << " + " << bias.GetByOffset("biasIdx + halfChannels") << ";\n"
39+
<< "let geluRight = valueRight * 0.5 * (erf_v(valueRight / M_SQRT2) + 1);\n"
40+
<< output.SetByOffset("global_idx", "valueLeft * geluRight");
41+
42+
return Status::OK();
43+
}
44+
45+
Status BiasSplitGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
46+
const auto* input = context.Input(0);
47+
const auto* bias = context.Input(1);
48+
49+
TensorShape input_shape = input->Shape();
50+
51+
if (input_shape.NumDimensions() != 3) {
52+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu input should have 3 dimensions.");
53+
}
54+
55+
int64_t channels = input_shape[2];
56+
int64_t components = GetMaxComponents(channels);
57+
channels /= components;
58+
input_shape[2] = channels / 2; // for output shape calculation (N,S,D) -> (N,S,D/2)
59+
60+
TensorShape bias_shape = bias->Shape();
61+
if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) {
62+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu bias should have 1 dimension with size equal to the number of channels.");
63+
}
64+
65+
auto* output = context.Output(0, input_shape);
66+
int64_t output_size = output->Shape().Size() / components;
67+
68+
BiasSplitGeluProgram program{};
69+
program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
70+
{bias}})
71+
.AddOutput({output})
72+
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
73+
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
74+
{static_cast<uint32_t>(channels)}});
75+
return context.RunProgram(program);
76+
}
77+
78+
} // namespace webgpu
79+
} // namespace contrib
80+
} // namespace onnxruntime
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/program.h"
7+
#include "core/providers/webgpu/webgpu_kernel.h"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace webgpu {
12+
13+
using namespace onnxruntime::webgpu;
14+
using onnxruntime::webgpu::ComputeContext;
15+
16+
class BiasSplitGeluProgram final : public Program<BiasSplitGeluProgram> {
17+
public:
18+
BiasSplitGeluProgram() : Program{"BiasSplitGelu"} {}
19+
Status GenerateShaderCode(ShaderHelper& sh) const override;
20+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
21+
{"channels", ProgramUniformVariableDataType::Uint32});
22+
};
23+
24+
class BiasSplitGelu final : public WebGpuKernel {
25+
public:
26+
BiasSplitGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}
27+
Status ComputeInternal(ComputeContext& context) const override;
28+
};
29+
30+
} // namespace webgpu
31+
} // namespace contrib
32+
} // namespace onnxruntime
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/shader_helper.h"
5+
#include "core/providers/webgpu/webgpu_supported_types.h"
6+
#include "core/providers/webgpu/math/unary_elementwise_ops.h" // contains Gelu definition
7+
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace webgpu {
12+
13+
using namespace onnxruntime::webgpu;
14+
using onnxruntime::webgpu::ComputeContext;
15+
16+
ONNX_OPERATOR_KERNEL_EX(
17+
Gelu,
18+
kMSDomain,
19+
1,
20+
kWebGpuExecutionProvider,
21+
(*KernelDefBuilder::Create())
22+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
23+
Gelu);
24+
25+
} // namespace webgpu
26+
} // namespace contrib
27+
} // namespace onnxruntime
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/shader_helper.h"
5+
#include "core/providers/webgpu/webgpu_supported_types.h"
6+
#include "core/providers/webgpu/math/unary_elementwise_ops.h" // contained Gelu definition
7+
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace webgpu {
12+
13+
using namespace onnxruntime::webgpu;
14+
using onnxruntime::webgpu::ComputeContext;
15+
16+
ONNX_OPERATOR_KERNEL_EX(
17+
QuickGelu,
18+
kMSDomain,
19+
1,
20+
kWebGpuExecutionProvider,
21+
(*KernelDefBuilder::Create())
22+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
23+
QuickGelu);
24+
25+
} // namespace webgpu
26+
} // namespace contrib
27+
} // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
3838
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
3939
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
4040
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
41-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
41+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
4242
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
4343
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
44-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
44+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
4545
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
4646
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
4747
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
48-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
48+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
4949
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
5050
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>,
5151
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it

onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -256,26 +256,6 @@ WEBGPU_CLIP_KERNEL(MLFloat16)
256256
// activation
257257
//
258258

259-
class LinearUnit : public UnaryElementwise {
260-
public:
261-
LinearUnit(const OpKernelInfo& info,
262-
const std::string& kernel_name,
263-
const std::string& expression,
264-
const std::string& additional_impl,
265-
float default_alpha)
266-
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
267-
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
268-
}
269-
270-
Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override {
271-
program.AddUniformVariables({alpha_});
272-
return Status::OK();
273-
}
274-
275-
protected:
276-
float alpha_;
277-
};
278-
279259
#define WEBGPU_LU_IMPL(OP_TYPE, ...) \
280260
class OP_TYPE final : public LinearUnit { \
281261
public: \
@@ -285,17 +265,17 @@ class LinearUnit : public UnaryElementwise {
285265
WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0)
286266
WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes())
287267

288-
class Gelu : public UnaryElementwise {
289-
public:
290-
Gelu(const OpKernelInfo& info)
291-
: UnaryElementwise{info,
292-
"Gelu",
293-
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
294-
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
295-
ShaderUsage::UseValueTypeAlias} {
296-
cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
297-
}
298-
};
268+
Gelu::Gelu(const OpKernelInfo& info)
269+
: UnaryElementwise{info,
270+
"Gelu",
271+
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
272+
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
273+
ShaderUsage::UseValueTypeAlias} {
274+
cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
275+
}
276+
277+
QuickGelu::QuickGelu(const OpKernelInfo& info)
278+
: LinearUnit{info, "QuickGelu", "quick_gelu_v(a)", QuickGeluImpl, 1.702f} {}
299279

300280
WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes())
301281

@@ -312,4 +292,4 @@ WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4<x_element_t>(0), a, a > vec4<x_elem
312292
WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes())
313293

314294
} // namespace webgpu
315-
} // namespace onnxruntime
295+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,36 @@ class UnaryElementwise : public WebGpuKernel {
6060
ShaderUsage additional_usage_;
6161
};
6262

63+
class Gelu : public UnaryElementwise {
64+
public:
65+
Gelu(const OpKernelInfo& info);
66+
};
67+
68+
class LinearUnit : public UnaryElementwise {
69+
public:
70+
LinearUnit(const OpKernelInfo& info,
71+
const std::string& kernel_name,
72+
const std::string& expression,
73+
const std::string& additional_impl,
74+
float default_alpha)
75+
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
76+
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
77+
}
78+
79+
Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override {
80+
program.AddUniformVariables({alpha_});
81+
return Status::OK();
82+
}
83+
84+
protected:
85+
float alpha_;
86+
};
87+
88+
class QuickGelu : public LinearUnit {
89+
public:
90+
QuickGelu(const OpKernelInfo& info);
91+
};
92+
6393
constexpr const char ErfImpl[] = R"(
6494
const r0 = 0.3275911;
6595
const r1 = 0.254829592;
@@ -104,11 +134,29 @@ fn elu_v(v: vec4<x_element_t>) -> vec4<x_element_t> {
104134
}
105135
)";
106136

137+
constexpr const char QuickGeluImpl[] = R"(
138+
fn quick_gelu_v(a: vec4<x_element_t>) -> vec4<x_element_t> {
139+
let one = 1.0;
140+
let zero = 0.0;
141+
let alpha_vec = vec4<x_element_t>(uniforms.attr);
142+
let v = a * alpha_vec;
143+
var x1 : vec4<x_element_t>;
144+
for (var i = 0; i < 4; i = i + 1) {
145+
if (v[i] >= zero) {
146+
x1[i] = one / (one + exp(-v[i]));
147+
} else {
148+
x1[i] = one - one / (one + exp(v[i]));
149+
}
150+
}
151+
return a * x1;
152+
}
153+
)";
154+
107155
// default GELU expression, depending on ErfImpl
108156
constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))";
109157

110158
// fast GELU expression, depending on TanhImpl
111159
constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))";
112160

113161
} // namespace webgpu
114-
} // namespace onnxruntime
162+
} // namespace onnxruntime
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed under the MIT License.
2+
3+
#pragma once
4+
5+
#include <cstdint>
6+
7+
namespace onnxruntime {
8+
namespace webgpu {
9+
10+
inline int64_t GetMaxComponents(int64_t size) {
11+
if (size % 4 == 0) {
12+
return 4;
13+
} else if (size % 2 == 0) {
14+
return 2;
15+
}
16+
return 1;
17+
}
18+
19+
} // namespace webgpu
20+
} // namespace onnxruntime

0 commit comments

Comments
 (0)