@@ -43,14 +43,71 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

- #define SHARED_MEMORY_FACTOR 2
#define MAX_WORKGROUP_SIZE 64
+ // Shared memory factor increases shared memory allocation by a scale that should either be 1 or a power of 2.
+ //
+ // Increasing the factor allows more data to be stored in shared memory and increases thread utilization during reduction.
+ // Why? Because when performing reduction, the number of active threads is halved in each iteration.
+ // Increasing the scaling factor increases thread occupancy and hence utilizes the GPU better.
+ // e.g.
+ // If the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 1, 32 elements will be loaded into shared memory.
+ // The first iteration of the reduce will have 16 threads summing up 32 elements.
+ // The second iteration will have 8 threads summing up the 16 elements from the previous iteration, and so on.
+ // So thread utilization starts at 50%.
+ //
+ // By contrast, if the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 2, 64 elements will be loaded into shared memory.
+ // The first iteration of the reduce will have 32 threads summing up 64 elements.
+ // The second iteration will have 16 threads summing up the 32 elements from the previous iteration, and so on.
+ // Thus thread utilization starts at 100%.
+ #define SHARED_MEMORY_FACTOR 2
+
#define offset_pos_index(index) ((index) + ((index) >> 2))
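+ // offset_pos_index() pads each index by an extra (index >> 2), presumably to spread
+ // accesses apart and reduce shared memory bank conflicts. With MAX_WORKGROUP_SIZE = 64
+ // and SHARED_MEMORY_FACTOR = 2, the shared array below holds
+ // offset_pos_index(64 * 2) = 128 + 32 = 160 texels.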
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
- // function to reduce input data in workgroup's x dimension
+ // Function to reduce input data in the workgroup's x dimension
+ //
+ // The implementation resembles the reduction depicted below:
+ // | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | 2 | 3 | 2 | 7 | 0 | 11 | 0 | 2 |  current_stride -> 1
+ // | / | / | / | / | / | / | / | /
+ // | / | / | / | / | / | / | / | /
+ // | / | / | / | / | / | / | / | /
+ // | 11 | 1 | 9 | 1 | 2 | 2 | 8 | 5 | 5 | 3 | 9 | 7 | 11 | 11 | 2 | 2 |  current_stride -> 2
+ // | / | / | / | /
+ // | / | / | / | /
+ // | / | / | / | /
+ // | 20 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 14 | 3 | 9 | 7 | 13 | 11 | 2 | 2 |  current_stride -> 4
+ // | / | /
+ // | / | /
+ // | / | /
+ // | / | /
+ // | / | /
+ // | 30 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 |  current_stride -> 8
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | 57 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 |  current_stride -> 16
+ //
+ // Threads access the shared index in the following pattern:
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |  current_stride -> 1
+ // Shared Index | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | X | X | X | X | X | X | X | X |  index *= 1
+ //
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |  current_stride -> 2
+ // Shared Index | 0 | 4 | 8 | 12 | X | X | X | X | X | X | X | X | X | X | X | X |  index *= 2
+ //
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |  current_stride -> 4
+ // Shared Index | 0 | 8 | X | X | X | X | X | X | X | X | X | X | X | X | X | X |  index *= 4
+ //
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |  current_stride -> 8
+ // Shared Index | 0 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X |  index *= 8
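+ //
+ // (Equivalently, as the tables above show: at current_stride s, thread t combines the
+ // partial sums at shared indices 2*t*s and 2*t*s + s; threads whose index would fall
+ // outside the segment stay idle.)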
+
void reduce_input(const int width_stride, const int shared_idx_offset) {
  // wait for all shared memory writes to finish
  memoryBarrierShared();
@@ -70,10 +127,9 @@ void reduce_input(const int width_stride, const int shared_idx_offset) {
  }
}

- void main() {
+ void reduce_non_packed_dim() {
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  const int width = int(sizes.x);
-
  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);

  // width batch read stride
@@ -85,148 +141,177 @@ void main() {
  // local memory index for this thread
  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
-   // if packed dimension width
-   if (in_packed_dim != W_DIM) {
-     VEC4_T mean = VEC4_T(0);
-     VEC4_T var = VEC4_T(0);
-
-     // Loop over the width in stride increments
-     for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
-       // Read input in shared memory
-       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-         in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-
-         VEC4_T in_val = VEC4_T(0);
-         if (all(lessThan(in_pos, out_limits))) {
-           in_val = load_texel(t_in, in_pos);
-         }
-         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
-       }
+   VEC4_T mean = VEC4_T(0);
+   VEC4_T var = VEC4_T(0);

-       reduce_input(width_stride, shared_idx_offset);
-       mean += shared_input[offset_pos_index(shared_idx_offset)];
+   // Loop over the width in stride increments
+   for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+     // Read input in shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
+       VEC4_T in_val = VEC4_T(0);
+       if (all(lessThan(in_pos, out_limits))) {
+         in_val = load_texel(t_in, in_pos);
+       }
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
    }

-     mean /= width;
+     reduce_input(width_stride, shared_idx_offset);
+     mean += shared_input[offset_pos_index(shared_idx_offset)];
+   }
+
+   mean /= width;

-     // Loop over the width in stride increments
-     for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
-       // Read input in shared memory
-       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-         in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
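+   // make sure every invocation has finished reading the reduced sums for the mean
+   // before shared_input is overwritten by the variance pass below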
+   memoryBarrierShared();
+   barrier();

-         VEC4_T in_val = mean;
-         if (all(lessThan(in_pos, out_limits))) {
-           in_val = load_texel(t_in, in_pos);
-         }
+   // Loop over the width in stride increments
+   for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+     // Read input in shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);

-         const VEC4_T delta = in_val - mean;
-         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+       VEC4_T in_val = mean;
+       if (all(lessThan(in_pos, out_limits))) {
+         in_val = load_texel(t_in, in_pos);
      }

-       reduce_input(width_stride, shared_idx_offset);
-       var += shared_input[offset_pos_index(shared_idx_offset)];
+       const VEC4_T delta = in_val - mean;
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
    }

-     var /= width;
+     reduce_input(width_stride, shared_idx_offset);
+     var += shared_input[offset_pos_index(shared_idx_offset)];
+   }

-     VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-     VEC4_T offset = -rstd * mean;
+   var /= width;

-     VEC4_T v = load_texel(t_in, lpos);
-     VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
-     VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
-     VEC4_T outtex = (v * rstd + offset) * weight + bias;
-     if (all(lessThan(lpos, out_limits))) {
-       write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
-     }
+   VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+   VEC4_T offset = -rstd * mean;
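+   // note: v * rstd + offset == (v - mean) * rstd, i.e. the usual
+   // (x - mean) / sqrt(var + epsilon) normalization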
-     if (gl_GlobalInvocationID.x == 0) {
-       write_texel(t_mean, lpos, mean);
-       write_texel(t_rstd, lpos, rstd);
-     }
-   } else {
-     const int last_packed_width_index = divup4(width) - 1;
-     T mean = T(0);
-     T var = T(0);
-     const int remain = width & 3;
-
-     const int in_pos_x_limit = out_limits[in_axis_map.x];
-
-     // Loop over the width in stride increments
-     for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
-       // Read input in shared memory
-       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-         const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-         in_pos[in_axis_map.x] = in_pos_x;
-
-         VEC4_T in_val = VEC4_T(0);
-         if (in_pos_x < in_pos_x_limit) {
-           in_val = load_texel(t_in, in_pos);
-         }
-
-         if (in_pos_x == last_packed_width_index && remain != 0) {
-           const int remain_inv = 4 - remain;
-           in_val.y = mix(in_val.y, T(0), remain_inv > 2);
-           in_val.z = mix(in_val.z, T(0), remain_inv > 1);
-           in_val.w = mix(in_val.w, T(0), remain_inv > 0);
-         }
-
-         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+   VEC4_T v = load_texel(t_in, lpos);
+   VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+   VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+   VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+   if (all(lessThan(lpos, out_limits))) {
+     write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+   }
+
+   if (gl_GlobalInvocationID.x == 0) {
+     write_texel(t_mean, lpos, mean);
+     write_texel(t_rstd, lpos, rstd);
+   }
+ }
+
+ void reduce_packed_dim() {
+   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+   const int width = int(sizes.x);
+   ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+   // width batch read stride
+   const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+   // local memory starting offset for this thread
+   const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+   // local memory index for this thread
+   const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+   const int last_packed_width_index = divup4(width) - 1;
+   T mean = T(0);
+   T var = T(0);
+   const int remain = width & 3;
+
+   const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+   // Loop over the width in stride increments
+   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+     // Read input in shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+       in_pos[in_axis_map.x] = in_pos_x;
+
+       VEC4_T in_val = VEC4_T(0);
+       if (in_pos_x < in_pos_x_limit) {
+         in_val = load_texel(t_in, in_pos);
      }

-       reduce_input(width_stride, shared_idx_offset);
-       const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-       mean += val.x + val.y + val.z + val.w;
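+       // when width is not a multiple of 4, the last texel holds only `remain` valid
+       // lanes; zero out the padding lanes so they do not contribute to the sum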
+       if (in_pos_x == last_packed_width_index && remain != 0) {
+         const int remain_inv = 4 - remain;
+         in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+         in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+         in_val.w = mix(in_val.w, T(0), remain_inv > 0);
+       }
+
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
    }

-     mean /= width;
-
-     // Loop over the width in stride increments
-     for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
-       // Read input in shared memory
-       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-         const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-         in_pos[in_axis_map.x] = in_pos_x;
-
-         VEC4_T in_val = VEC4_T(mean);
-         if (in_pos_x < in_pos_x_limit) {
-           in_val = load_texel(t_in, in_pos);
-         }
-
-         if (in_pos_x == last_packed_width_index && remain != 0) {
-           const int remain_inv = 4 - remain;
-           in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
-           in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
-           in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
-         }
-
-         const VEC4_T delta = in_val - mean;
-         const VEC4_T delta2 = delta * delta;
-         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
+     reduce_input(width_stride, shared_idx_offset);
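+     // each texel packs 4 consecutive elements along the width, so all four lanes of the
+     // reduced texel are folded into the scalar mean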
+     const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+     mean += val.x + val.y + val.z + val.w;
+   }
+
+   mean /= width;
+
+   memoryBarrierShared();
+   barrier();
+
+   // Loop over the width in stride increments
+   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+     // Read input in shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+       in_pos[in_axis_map.x] = in_pos_x;
+
+       VEC4_T in_val = VEC4_T(mean);
+       if (in_pos_x < in_pos_x_limit) {
+         in_val = load_texel(t_in, in_pos);
+       }
+
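+       // for the variance pass, set the padding lanes of the last texel to the mean so
+       // that their deltas are zero and they do not affect the result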
+       if (in_pos_x == last_packed_width_index && remain != 0) {
+         const int remain_inv = 4 - remain;
+         in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
+         in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
+         in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
      }

-       reduce_input(width_stride, shared_idx_offset);
-       const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-       var += val.x + val.y + val.z + val.w;
+       const VEC4_T delta = in_val - mean;
+       const VEC4_T delta2 = delta * delta;
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
    }

-     var /= width;
+     reduce_input(width_stride, shared_idx_offset);
+     const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+     var += val.x + val.y + val.z + val.w;
+   }
+
+   var /= width;

-     T rstd = pow(var + epsilon, T(-0.5));
-     T offset = -rstd * mean;
+   T rstd = pow(var + epsilon, T(-0.5));
+   T offset = -rstd * mean;

-     VEC4_T v = load_texel(t_in, lpos);
-     VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
-     VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
-     VEC4_T outtex = (v * rstd + offset) * weight + bias;
-     if (all(lessThan(lpos, out_limits))) {
-       write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
-     }
+   VEC4_T v = load_texel(t_in, lpos);
+   VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+   VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+   VEC4_T outtex = (v * rstd + offset) * weight + bias;

-     if (gl_GlobalInvocationID.x == 0) {
-       write_texel(t_mean, lpos, VEC4_T(mean));
-       write_texel(t_rstd, lpos, VEC4_T(rstd));
-     }
+   if (all(lessThan(lpos, out_limits))) {
+     write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+   }
+
+   if (gl_GlobalInvocationID.x == 0) {
+     write_texel(t_mean, lpos, VEC4_T(mean));
+     write_texel(t_rstd, lpos, VEC4_T(rstd));
+   }
+ }
+
+ void main() {
+   // Choose the reduction path based on whether the width is the packed dimension.
+   if (in_packed_dim != W_DIM) {
+     reduce_non_packed_dim();
+   } else {
+     reduce_packed_dim();
  }
}