@@ -36,36 +36,107 @@ WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
36
36
WEBGPU_GEMM_KERNEL(13 )
37
37
38
38
Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
39
+ const uint32_t TILE_SIZE = 16 ;
40
+
41
+ // Add shared memory arrays
42
+ shader.AdditionalImplementation () << " var<workgroup> tile_a: array<array<output_value_t, " << TILE_SIZE << " >, " << TILE_SIZE << " >;\n "
43
+ << " var<workgroup> tile_b: array<array<output_value_t, " << TILE_SIZE << " >, " << TILE_SIZE << " >;\n\n " ;
44
+
39
45
const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
40
46
41
- shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
42
- << " let m = global_idx / uniforms.N;\n "
43
- << " let n = global_idx % uniforms.N;\n "
44
- << " var value = output_value_t(0);\n "
45
- << " \n " ;
47
+ shader.MainFunctionBody () << " var value = output_value_t(0);\n\n "
48
+ << " let tile_col_start = (workgroup_id.x % uniforms.num_tile_n) * " << TILE_SIZE << " u;\n "
49
+ << " let tile_row_start = (workgroup_id.x / uniforms.num_tile_n) * " << TILE_SIZE << " u;\n " ;
46
50
47
- // When K == 0, we don't bind A and B. Because WebGPU doesn't support binding a zero-sized buffer,
48
51
if (need_handle_matmul_) {
49
52
const ShaderVariableHelper& A = shader.AddInput (" A" , ShaderUsage::UseUniform);
50
53
const ShaderVariableHelper& B = shader.AddInput (" B" , ShaderUsage::UseUniform);
51
54
52
- shader.MainFunctionBody () << " for (var k = 0u; k < uniforms.K; k = k + 1u) {\n " ;
55
+ shader.MainFunctionBody ()
56
+ << " let num_tiles = (uniforms.K - 1u) / " << TILE_SIZE << " u + 1u;\n "
57
+ << " var k_start = 0u;\n "
58
+ << " for (var t = 0u; t < num_tiles; t = t + 1u) {\n " ;
53
59
60
+ // Fill workgroup shared memory
54
61
if (transA_ && transB_) {
55
- shader.MainFunctionBody () << " value = value + " << A.GetByOffset (" k * uniforms.M + m" )
56
- << " * " << B.GetByOffset (" n * uniforms.K + k" ) << " ;\n " ;
62
+ shader.MainFunctionBody () << " var col = tile_row_start + local_id.x;\n "
63
+ << " var row = k_start + local_id.y;\n "
64
+ << " if (col < uniforms.M && row < uniforms.K) {\n "
65
+ << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset (" row * uniforms.M + col" ) << " ;\n "
66
+ << " } else {\n "
67
+ << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n "
68
+ << " }\n\n "
69
+ << " col = k_start + local_id.x;\n "
70
+ << " row = tile_col_start + local_id.y;\n "
71
+ << " if (col < uniforms.K && row < uniforms.N) {\n "
72
+ << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset (" row * uniforms.K + col" ) << " ;\n "
73
+ << " } else {\n "
74
+ << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n "
75
+ << " }\n " ;
57
76
} else if (transA_ && !transB_) {
58
- shader.MainFunctionBody () << " value = value + " << A.GetByOffset (" k * uniforms.M + m" )
59
- << " * " << B.GetByOffset (" k * uniforms.N + n" ) << " ;\n " ;
77
+ shader.MainFunctionBody () << " var col = tile_row_start + local_id.x;\n "
78
+ << " var row = k_start + local_id.y;\n "
79
+ << " if (col < uniforms.M && row < uniforms.K) {\n "
80
+ << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset (" row * uniforms.M + col" ) << " ;\n "
81
+ << " } else {\n "
82
+ << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n "
83
+ << " }\n\n "
84
+ << " col = tile_col_start + local_id.x;\n "
85
+ << " row = k_start + local_id.y;\n "
86
+ << " if (col < uniforms.N && row < uniforms.K) {\n "
87
+ << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset (" row * uniforms.N + col" ) << " ;\n "
88
+ << " } else {\n "
89
+ << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n "
90
+ << " }\n " ;
60
91
} else if (!transA_ && transB_) {
61
- shader.MainFunctionBody () << " value = value + " << A.GetByOffset (" m * uniforms.K + k" )
62
- << " * " << B.GetByOffset (" n * uniforms.K + k" ) << " ;\n " ;
92
+ shader.MainFunctionBody () << " var col = k_start + local_id.x;\n "
93
+ << " var row = tile_row_start + local_id.y;\n "
94
+ << " if (col < uniforms.K && row < uniforms.M) {\n "
95
+ << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset (" row * uniforms.K + col" ) << " ;\n "
96
+ << " } else {\n "
97
+ << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n "
98
+ << " }\n\n "
99
+ << " col = k_start + local_id.x;\n "
100
+ << " row = tile_col_start + local_id.y;\n "
101
+ << " if (col < uniforms.K && row < uniforms.N) {\n "
102
+ << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset (" row * uniforms.K + col" ) << " ;\n "
103
+ << " } else {\n "
104
+ << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n "
105
+ << " }\n " ;
63
106
} else {
64
- shader.MainFunctionBody () << " value = value + " << A.GetByOffset (" m * uniforms.K + k" )
65
- << " * " << B.GetByOffset (" k * uniforms.N + n" ) << " ;\n " ;
107
+ shader.MainFunctionBody () << " var col = k_start + local_id.x;\n "
108
+ << " var row = tile_row_start + local_id.y;\n "
109
+ << " if (col < uniforms.K && row < uniforms.M) {\n "
110
+ << " tile_a[local_id.y][local_id.x] = " << A.GetByOffset (" row * uniforms.K + col" ) << " ;\n "
111
+ << " } else {\n "
112
+ << " tile_a[local_id.y][local_id.x] = output_value_t(0);\n "
113
+ << " }\n\n "
114
+ << " col = tile_col_start + local_id.x;\n "
115
+ << " row = k_start + local_id.y;\n "
116
+ << " if (col < uniforms.N && row < uniforms.K) {\n "
117
+ << " tile_b[local_id.y][local_id.x] = " << B.GetByOffset (" row * uniforms.N + col" ) << " ;\n "
118
+ << " } else {\n "
119
+ << " tile_b[local_id.y][local_id.x] = output_value_t(0);\n "
120
+ << " }\n " ;
66
121
}
67
- shader.MainFunctionBody () << " }\n "
68
- << " \n " ;
122
+
123
+ shader.MainFunctionBody () << " k_start = k_start + " << TILE_SIZE << " u;\n "
124
+ << " workgroupBarrier();\n\n "
125
+ << " for (var k = 0u; k < " << TILE_SIZE << " u; k = k + 1u) {\n " ;
126
+
127
+ if (transA_ && transB_) {
128
+ shader.MainFunctionBody () << " value = value + tile_a[k][local_id.y] * tile_b[local_id.x][k];\n " ;
129
+ } else if (transA_ && !transB_) {
130
+ shader.MainFunctionBody () << " value = value + tile_a[k][local_id.y] * tile_b[k][local_id.x];\n " ;
131
+ } else if (!transA_ && transB_) {
132
+ shader.MainFunctionBody () << " value = value + tile_a[local_id.y][k] * tile_b[local_id.x][k];\n " ;
133
+ } else {
134
+ shader.MainFunctionBody () << " value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];\n " ;
135
+ }
136
+
137
+ shader.MainFunctionBody () << " }\n "
138
+ << " workgroupBarrier();\n "
139
+ << " }\n\n " ;
69
140
}
70
141
71
142
// Calculate Alpha
@@ -76,11 +147,19 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
76
147
// Calculate Bias
77
148
if (need_handle_bias_) {
78
149
const ShaderVariableHelper& C = shader.AddInput (" C" , ShaderUsage::UseUniform);
79
- shader.MainFunctionBody () << " value = value + output_value_t(uniforms.beta) * "
150
+ shader.MainFunctionBody () << " let m = tile_row_start + local_id.y;\n "
151
+ << " let n = tile_col_start + local_id.x;\n "
152
+ << " value = value + output_value_t(uniforms.beta) * "
80
153
<< C.GetByOffset (C.BroadcastedIndicesToOffset (" vec2(m, n)" , output)) << " ;\n " ;
154
+ } else {
155
+ shader.MainFunctionBody () << " let m = tile_row_start + local_id.y;\n "
156
+ << " let n = tile_col_start + local_id.x;\n " ;
81
157
}
82
158
83
- shader.MainFunctionBody () << output.SetByOffset (" global_idx" , " value" ) << " \n " ;
159
+ // Write output
160
+ shader.MainFunctionBody () << " if (m < uniforms.M && n < uniforms.N) {\n "
161
+ << " " << output.SetByOffset (" m * uniforms.N + n" , " value" ) << " \n "
162
+ << " }\n " ;
84
163
85
164
return Status::OK ();
86
165
}
@@ -132,16 +211,20 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
132
211
program.AddInput ({C, ProgramTensorMetadataDependency::Rank});
133
212
}
134
213
214
+ const uint32_t TILE_SIZE = 16 ;
215
+ const uint32_t num_tile_n = (N + TILE_SIZE - 1 ) / TILE_SIZE;
216
+ const uint32_t num_tile_m = (M + TILE_SIZE - 1 ) / TILE_SIZE;
217
+
135
218
program.AddOutputs ({{Y, ProgramTensorMetadataDependency::Type}})
136
- .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE )
137
- .SetWorkgroupSize (WORKGROUP_SIZE )
219
+ .SetDispatchGroupSize (num_tile_n * num_tile_m )
220
+ .SetWorkgroupSize (TILE_SIZE, TILE_SIZE )
138
221
.AddUniformVariables ({
139
- {static_cast <uint32_t >(output_size )}, // output_size
140
- {static_cast <uint32_t >(M)}, // M
141
- {static_cast <uint32_t >(N)}, // N
142
- {static_cast <uint32_t >(K)}, // K
143
- {alpha_}, // alpha
144
- {beta_} // beta
222
+ {static_cast <uint32_t >(num_tile_n )}, // num_tile_n
223
+ {static_cast <uint32_t >(M)}, // M
224
+ {static_cast <uint32_t >(N)}, // N
225
+ {static_cast <uint32_t >(K)}, // K
226
+ {alpha_}, // alpha
227
+ {beta_} // beta
145
228
});
146
229
147
230
return context.RunProgram (program);
0 commit comments