Skip to content

Commit 99b07eb

Browse files
author
Xiaofei Han
committed
optimize
1 parent 9438a2b commit 99b07eb

File tree

2 files changed

+111
-28
lines changed

2 files changed

+111
-28
lines changed

onnxruntime/core/providers/webgpu/math/gemm.cc

+110-27
Original file line numberDiff line numberDiff line change
@@ -36,36 +36,107 @@ WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
3636
WEBGPU_GEMM_KERNEL(13)
3737

3838
Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
39+
const uint32_t TILE_SIZE = 16;
40+
41+
// Add shared memory arrays
42+
shader.AdditionalImplementation() << "var<workgroup> tile_a: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n"
43+
<< "var<workgroup> tile_b: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n\n";
44+
3945
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
4046

41-
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
42-
<< " let m = global_idx / uniforms.N;\n"
43-
<< " let n = global_idx % uniforms.N;\n"
44-
<< " var value = output_value_t(0);\n"
45-
<< "\n";
47+
shader.MainFunctionBody() << " var value = output_value_t(0);\n\n"
48+
<< " let tile_col_start = (workgroup_id.x % uniforms.num_tile_n) * " << TILE_SIZE << "u;\n"
49+
<< " let tile_row_start = (workgroup_id.x / uniforms.num_tile_n) * " << TILE_SIZE << "u;\n";
4650

47-
// When K == 0, we don't bind A and B. Because WebGPU doesn't support binding a zero-sized buffer,
4851
if (need_handle_matmul_) {
4952
const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform);
5053
const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform);
5154

52-
shader.MainFunctionBody() << " for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";
55+
shader.MainFunctionBody()
56+
<< " let num_tiles = (uniforms.K - 1u) / " << TILE_SIZE << "u + 1u;\n"
57+
<< " var k_start = 0u;\n"
58+
<< " for (var t = 0u; t < num_tiles; t = t + 1u) {\n";
5359

60+
// Fill workgroup shared memory
5461
if (transA_ && transB_) {
55-
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
56-
<< " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
62+
shader.MainFunctionBody() << " var col = tile_row_start + local_id.x;\n"
63+
<< " var row = k_start + local_id.y;\n"
64+
<< " if (col < uniforms.M && row < uniforms.K) {\n"
65+
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
66+
<< " } else {\n"
67+
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
68+
<< " }\n\n"
69+
<< " col = k_start + local_id.x;\n"
70+
<< " row = tile_col_start + local_id.y;\n"
71+
<< " if (col < uniforms.K && row < uniforms.N) {\n"
72+
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
73+
<< " } else {\n"
74+
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
75+
<< " }\n";
5776
} else if (transA_ && !transB_) {
58-
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
59-
<< " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
77+
shader.MainFunctionBody() << " var col = tile_row_start + local_id.x;\n"
78+
<< " var row = k_start + local_id.y;\n"
79+
<< " if (col < uniforms.M && row < uniforms.K) {\n"
80+
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
81+
<< " } else {\n"
82+
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
83+
<< " }\n\n"
84+
<< " col = tile_col_start + local_id.x;\n"
85+
<< " row = k_start + local_id.y;\n"
86+
<< " if (col < uniforms.N && row < uniforms.K) {\n"
87+
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
88+
<< " } else {\n"
89+
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
90+
<< " }\n";
6091
} else if (!transA_ && transB_) {
61-
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
62-
<< " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
92+
shader.MainFunctionBody() << " var col = k_start + local_id.x;\n"
93+
<< " var row = tile_row_start + local_id.y;\n"
94+
<< " if (col < uniforms.K && row < uniforms.M) {\n"
95+
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
96+
<< " } else {\n"
97+
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
98+
<< " }\n\n"
99+
<< " col = k_start + local_id.x;\n"
100+
<< " row = tile_col_start + local_id.y;\n"
101+
<< " if (col < uniforms.K && row < uniforms.N) {\n"
102+
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
103+
<< " } else {\n"
104+
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
105+
<< " }\n";
63106
} else {
64-
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
65-
<< " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
107+
shader.MainFunctionBody() << " var col = k_start + local_id.x;\n"
108+
<< " var row = tile_row_start + local_id.y;\n"
109+
<< " if (col < uniforms.K && row < uniforms.M) {\n"
110+
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
111+
<< " } else {\n"
112+
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
113+
<< " }\n\n"
114+
<< " col = tile_col_start + local_id.x;\n"
115+
<< " row = k_start + local_id.y;\n"
116+
<< " if (col < uniforms.N && row < uniforms.K) {\n"
117+
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
118+
<< " } else {\n"
119+
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
120+
<< " }\n";
66121
}
67-
shader.MainFunctionBody() << " }\n"
68-
<< "\n";
122+
123+
shader.MainFunctionBody() << " k_start = k_start + " << TILE_SIZE << "u;\n"
124+
<< " workgroupBarrier();\n\n"
125+
<< " for (var k = 0u; k < " << TILE_SIZE << "u; k = k + 1u) {\n";
126+
127+
if (transA_ && transB_) {
128+
shader.MainFunctionBody() << " value = value + tile_a[k][local_id.y] * tile_b[local_id.x][k];\n";
129+
} else if (transA_ && !transB_) {
130+
shader.MainFunctionBody() << " value = value + tile_a[k][local_id.y] * tile_b[k][local_id.x];\n";
131+
} else if (!transA_ && transB_) {
132+
shader.MainFunctionBody() << " value = value + tile_a[local_id.y][k] * tile_b[local_id.x][k];\n";
133+
} else {
134+
shader.MainFunctionBody() << " value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];\n";
135+
}
136+
137+
shader.MainFunctionBody() << " }\n"
138+
<< " workgroupBarrier();\n"
139+
<< " }\n\n";
69140
}
70141

71142
// Calculate Alpha
@@ -76,11 +147,19 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
76147
// Calculate Bias
77148
if (need_handle_bias_) {
78149
const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform);
79-
shader.MainFunctionBody() << " value = value + output_value_t(uniforms.beta) * "
150+
shader.MainFunctionBody() << " let m = tile_row_start + local_id.y;\n"
151+
<< " let n = tile_col_start + local_id.x;\n"
152+
<< " value = value + output_value_t(uniforms.beta) * "
80153
<< C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
154+
} else {
155+
shader.MainFunctionBody() << " let m = tile_row_start + local_id.y;\n"
156+
<< " let n = tile_col_start + local_id.x;\n";
81157
}
82158

83-
shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";
159+
// Write output
160+
shader.MainFunctionBody() << " if (m < uniforms.M && n < uniforms.N) {\n"
161+
<< " " << output.SetByOffset("m * uniforms.N + n", "value") << "\n"
162+
<< " }\n";
84163

85164
return Status::OK();
86165
}
@@ -132,16 +211,20 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
132211
program.AddInput({C, ProgramTensorMetadataDependency::Rank});
133212
}
134213

214+
const uint32_t TILE_SIZE = 16;
215+
const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
216+
const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;
217+
135218
program.AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
136-
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
137-
.SetWorkgroupSize(WORKGROUP_SIZE)
219+
.SetDispatchGroupSize(num_tile_n * num_tile_m)
220+
.SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
138221
.AddUniformVariables({
139-
{static_cast<uint32_t>(output_size)}, // output_size
140-
{static_cast<uint32_t>(M)}, // M
141-
{static_cast<uint32_t>(N)}, // N
142-
{static_cast<uint32_t>(K)}, // K
143-
{alpha_}, // alpha
144-
{beta_} // beta
222+
{static_cast<uint32_t>(num_tile_n)}, // num_tile_n
223+
{static_cast<uint32_t>(M)}, // M
224+
{static_cast<uint32_t>(N)}, // N
225+
{static_cast<uint32_t>(K)}, // K
226+
{alpha_}, // alpha
227+
{beta_} // beta
145228
});
146229

147230
return context.RunProgram(program);

onnxruntime/core/providers/webgpu/math/gemm.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class GemmProgram final : public Program<GemmProgram> {
2323
Status GenerateShaderCode(ShaderHelper& sh) const override;
2424

2525
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
26-
{"output_size", ProgramUniformVariableDataType::Uint32},
26+
{"num_tile_n", ProgramUniformVariableDataType::Uint32},
2727
{"M", ProgramUniformVariableDataType::Uint32},
2828
{"N", ProgramUniformVariableDataType::Uint32},
2929
{"K", ProgramUniformVariableDataType::Uint32},

0 commit comments

Comments
 (0)