Skip to content

Commit a2a7d5f

Browse files
author
Xiaofei Han
committed
implement
1 parent 325ee30 commit a2a7d5f

File tree

4 files changed

+218
-6
lines changed

4 files changed

+218
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/math/gemm.h"
5+
6+
#include "core/providers/webgpu/shader_helper.h"
7+
#include "core/providers/webgpu/webgpu_supported_types.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
// Registers the WebGPU Gemm kernel for the inclusive ONNX opset range [start, end].
// Expands to ONNX_OPERATOR_VERSIONED_KERNEL_EX for op "Gemm" in kOnnxDomain on
// kWebGpuExecutionProvider, constrains type "T" to WebGpuSupportedNumberTypes(), and
// binds the registration to the Gemm kernel class.
// The expansion already ends with ';', so call sites do not add one.
#define WEBGPU_GEMM_VERSIONED_KERNEL(start, end) \
13+
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
14+
Gemm, \
15+
kOnnxDomain, \
16+
start, \
17+
end, \
18+
kWebGpuExecutionProvider, \
19+
(*KernelDefBuilder::Create()) \
20+
.TypeConstraint("T", WebGpuSupportedNumberTypes()), \
21+
Gemm);
22+
23+
// Registers the WebGPU Gemm kernel starting at a single ONNX opset `version`.
// Expands to ONNX_OPERATOR_KERNEL_EX for op "Gemm" in kOnnxDomain on
// kWebGpuExecutionProvider, constrains type "T" to WebGpuSupportedNumberTypes(), and
// binds the registration to the Gemm kernel class.
// The expansion already ends with ';', so call sites do not add one.
#define WEBGPU_GEMM_KERNEL(version) \
24+
ONNX_OPERATOR_KERNEL_EX( \
25+
Gemm, \
26+
kOnnxDomain, \
27+
version, \
28+
kWebGpuExecutionProvider, \
29+
(*KernelDefBuilder::Create()) \
30+
.TypeConstraint("T", WebGpuSupportedNumberTypes()), \
31+
Gemm);
32+
33+
// Kernel registrations for Gemm: opset ranges 7-8, 9-10 and 11-12 use the versioned
// macro, opset 13 uses the non-versioned macro. No trailing ';' is needed because the
// macros supply it.
WEBGPU_GEMM_VERSIONED_KERNEL(7, 8)
34+
WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
35+
WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
36+
WEBGPU_GEMM_KERNEL(13)
37+
38+
// Generates the WGSL main-function body for Gemm.
//
// Each invocation handles the single output element whose flat index is `global_idx`:
//   m = global_idx / N, n = global_idx % N
//   value = sum over k in [0, K) of A(m, k) * B(k, n)
//   value = value * alpha
//   value = value + beta * C(broadcast to [m, n])   (only when need_handle_bias_ is set)
//
// The transpose flags only change how the flat offsets into A and B are formed, so the
// offsets are selected up front instead of duplicating the accumulate statement per case.
//
// @param shader  Helper that collects the shader inputs, outputs and main body.
// @return Status::OK() in all cases visible here.
Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  // With transA the element (m, k) is read at k * M + m, otherwise at m * K + k.
  const char* const a_offset = transA_ ? "k * uniforms.M + m" : "m * uniforms.K + k";
  // With transB the element (k, n) is read at n * K + k, otherwise at k * N + n.
  const char* const b_offset = transB_ ? "n * uniforms.K + k" : "k * uniforms.N + n";

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << " let m = global_idx / uniforms.N;\n"
                            << " let n = global_idx % uniforms.N;\n"
                            << " var value = A_value_t(0);\n"
                            << "\n"
                            << " for (var k = 0u; k < uniforms.K; k = k + 1u) {\n"
                            << " value = value + " << A.GetByOffset(a_offset)
                            << " * " << B.GetByOffset(b_offset) << ";\n"
                            << " }\n"
                            << "\n";

  // Calculate Alpha.
  // The scaling is emitted unconditionally: skipping it when alpha == 0 would leave the
  // unscaled A*B product in `value` instead of zeroing it. Emitting it always also keeps
  // the shader text independent of the alpha value.
  shader.MainFunctionBody() << " value = value * A_value_t(uniforms.alpha);\n";

  // Calculate Bias
  if (need_handle_bias_) {
    const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
    // C is addressed through the output indices (m, n) so that it is broadcast to the output shape.
    shader.MainFunctionBody() << " value = value + A_value_t(uniforms.beta) * "
                              << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
  }

  shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";

  return Status::OK();
}
82+
83+
// Runs Gemm on the WebGPU EP: validates A and B, allocates the [M, N] output and
// launches GemmProgram with one invocation per output element.
//
// Inputs:  0 = A (2-D), 1 = B (2-D), 2 = C (optional bias).
// Output:  0 = Y with shape [M, N].
//
// @return INVALID_ARGUMENT if A or B is missing, is not 2-D, or the inner dimensions
//         differ; Status::OK() without launching anything when the output is empty;
//         otherwise the status of RunProgram.
Status Gemm::ComputeInternal(ComputeContext& context) const {
  const auto* A = context.Input<Tensor>(0);
  const auto* B = context.Input<Tensor>(1);
  const auto* C = context.Input<Tensor>(2);

  if (A == nullptr || B == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Gemm requires input tensors A and B.");
  }

  const auto& A_shape = A->Shape();
  const auto& B_shape = B->Shape();

  if (A_shape.NumDimensions() != 2 || B_shape.NumDimensions() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2 dimensional.");
  }

  const int64_t M = transA_ ? A_shape[1] : A_shape[0];
  const int64_t K = transA_ ? A_shape[0] : A_shape[1];
  const int64_t N = transB_ ? B_shape[0] : B_shape[1];

  // K was taken from A; it must equal the matching dimension of B.
  if (K != (transB_ ? B_shape[1] : B_shape[0])) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
  }

  std::vector<int64_t> output_dims{M, N};
  auto* Y = context.Output(0, output_dims);
  const int64_t output_size = Y->Shape().Size();

  if (output_size == 0) {
    return Status::OK();
  }

  // The bias term is only needed when C is provided and beta is non-zero.
  const bool need_handle_bias = C != nullptr && beta_ != 0.0f;

  // The shader addresses the output through the flat `global_idx`, so use a 1-D workgroup
  // and enough workgroups to cover output_size. This does not rely on how a
  // multi-dimensional workgroup would be flattened into `global_idx`.
  constexpr uint32_t kThreadsPerWorkgroup = 64;
  const auto dispatch_size =
      static_cast<uint32_t>((output_size + kThreadsPerWorkgroup - 1) / kThreadsPerWorkgroup);

  GemmProgram program{transA_, transB_, alpha_, beta_, need_handle_bias};

  // The generated shader text differs per transpose combination, so those flags must be
  // part of the program cache key; otherwise a shader built for another combination could
  // be reused. Alpha's zero-ness and the bias flag are included so the key conservatively
  // covers every attribute the generator may branch on.
  program.CacheHint(static_cast<int>(transA_),
                    static_cast<int>(transB_),
                    static_cast<int>(alpha_ != 0.0f),
                    static_cast<int>(need_handle_bias));

  program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
                     {B, ProgramTensorMetadataDependency::Type}});

  if (need_handle_bias) {
    program.AddInput({C, ProgramTensorMetadataDependency::Rank});
  }

  program.AddOutputs({Y})
      .SetDispatchGroupSize(dispatch_size)
      .SetWorkgroupSize(kThreadsPerWorkgroup, 1, 1)
      .AddUniformVariables({
          {static_cast<uint32_t>(output_size)},  // output_size
          {static_cast<uint32_t>(M)},            // M
          {static_cast<uint32_t>(N)},            // N
          {static_cast<uint32_t>(K)},            // K
          {alpha_},                              // alpha
          {beta_}                                // beta
      });

  return context.RunProgram(program);
}
142+
143+
} // namespace webgpu
144+
} // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
#include "core/providers/webgpu/shader_helper.h"
8+
#include "core/providers/webgpu/program.h"
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
13+
// WebGPU program description for Gemm.
//
// Stores the transpose flags, the alpha/beta scalars and whether a bias term has to be
// added; GenerateShaderCode() reads them when emitting the shader. The program declares
// six uniforms, in this order: output_size, M, N, K (all Uint32), alpha, beta (Float32).
class GemmProgram final : public Program<GemmProgram> {
 public:
  // @param transA            read A as transposed.
  // @param transB            read B as transposed.
  // @param alpha             scale applied to the A*B product.
  // @param beta              scale applied to the bias input.
  // @param need_handle_bias  true when the shader must add the bias term.
  GemmProgram(bool transA, bool transB, float alpha, float beta, bool need_handle_bias)
      : Program{"Gemm"},
        transA_(transA),
        transB_(transB),
        alpha_(alpha),
        beta_(beta),
        need_handle_bias_(need_handle_bias) {
  }

  // Emits the shader source for this program into `sh`.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"output_size", ProgramUniformVariableDataType::Uint32},
      {"M", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32},
      {"alpha", ProgramUniformVariableDataType::Float32},
      {"beta", ProgramUniformVariableDataType::Float32});

 private:
  bool transA_;            // A is read as transposed
  bool transB_;            // B is read as transposed
  float alpha_;            // scale for A*B
  float beta_;             // scale for the bias
  bool need_handle_bias_;  // whether the bias term is emitted
};
40+
41+
class Gemm final : public WebGpuKernel {
42+
public:
43+
Gemm(const OpKernelInfo& info) : WebGpuKernel(info) {
44+
int64_t transA_temp;
45+
info.GetAttrOrDefault("transA", &transA_temp, static_cast<int64_t>(0));
46+
transA_ = transA_temp != 0;
47+
48+
int64_t transB_temp;
49+
info.GetAttrOrDefault("transB", &transB_temp, static_cast<int64_t>(0));
50+
transB_ = transB_temp != 0;
51+
52+
info.GetAttrOrDefault("alpha", &alpha_, 1.0f);
53+
info.GetAttrOrDefault("beta", &beta_, 1.0f);
54+
}
55+
56+
Status ComputeInternal(ComputeContext& context) const override;
57+
58+
private:
59+
bool transA_;
60+
bool transB_;
61+
float alpha_;
62+
float beta_;
63+
};
64+
65+
} // namespace webgpu
66+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
610610
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool)>,
611611
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool)>,
612612

613-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
614-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
615-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
616-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm)>,
613+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
614+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
615+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
616+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm)>,
617617
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
618618
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul)>,
619619

onnxruntime/test/providers/cpu/math/gemm_test.cc

+4-2
Original file line numberDiff line numberDiff line change
@@ -761,9 +761,10 @@ TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) {
761761
test.AddInput<TypeParam>("C", {4}, std::vector<TypeParam>(4, static_cast<TypeParam>(1.0f)));
762762
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(1.0f)));
763763

764+
// The WebGPU EP does not support zero-sized buffers, so it is excluded from this test.
764765
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
765766
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
766-
kOpenVINOExecutionProvider})
767+
kOpenVINOExecutionProvider, kWebGpuExecutionProvider})
767768
.Config(run_with_tunable_op)
768769
.RunWithConfig();
769770
}
@@ -780,9 +781,10 @@ TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) {
780781
test.AddInput<TypeParam>("B", {0, 4}, {});
781782
test.AddOutput<TypeParam>("Y", {4, 4}, std::vector<TypeParam>(16, static_cast<TypeParam>(0.0f)));
782783

784+
// The WebGPU EP does not support zero-sized buffers, so it is excluded from this test.
783785
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
784786
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
785-
kOpenVINOExecutionProvider})
787+
kOpenVINOExecutionProvider, kWebGpuExecutionProvider})
786788
.Config(run_with_tunable_op)
787789
.RunWithConfig();
788790
}

0 commit comments

Comments
 (0)