Commit 2ff9abd

[ET-VK] Manual sync native layer norm (#10242)
As title. The automated squash in #10238 did not sync the changes for the native layer norm shader for some reason.
1 parent 47fb157 commit 2ff9abd

1 file changed: +208 -123 lines changed

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

+208 -123
@@ -43,14 +43,71 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
 const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
 const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 
-#define SHARED_MEMORY_FACTOR 2
 #define MAX_WORKGROUP_SIZE 64
 
+// Shared memory factor increases shared memory allocation by a scale that should either be 1 or a power of 2.
+//
+// Increasing the factor allows more data to be stored in shared memory and increases thread utilization during reduction.
+// Why? Because when performing reduction, the number of active threads is halved in each iteration.
+// Increasing the scaling factor increases thread occupancy and hence utilizes the GPU better.
+// e.g.
+// If the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 1, 32 elements will be loaded into shared memory.
+// The first iteration of reduce will have 16 threads sum up 32 elements.
+// The second iteration will have 8 threads sum up 16 elements from the previous iteration, and so on.
+// So thread utilization starts at 50%.
+//
+// By contrast, if the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 2, 64 elements will be loaded into shared memory.
+// The first iteration of reduce will have 32 threads sum up 64 elements.
+// The second iteration will have 16 threads sum up 32 elements from the previous iteration, and so on.
+// Thus thread utilization starts at 100%.
+#define SHARED_MEMORY_FACTOR 2
+
 #define offset_pos_index(index) ((index) + ((index) >> 2))
 
 shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
 
-// function to reduce input data in workgroup's x dimension
+// Function to reduce input data in the workgroup's x dimension
+//
+// The implementation resembles the reduction depicted below
+// | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | 2 | 3 | 2 | 7 | 0 | 11 | 0 | 2 |   current_stride -> 1
+//   |  /  |  /  |  /  |  /  |  /  |  /  |  /   |  /
+//   |  /  |  /  |  /  |  /  |  /  |  /  |  /   |  /
+//   |  /  |  /  |  /  |  /  |  /  |  /  |  /   |  /
+// | 11 | 1 | 9 | 1 | 2 | 2 | 8 | 5 | 5 | 3 | 9 | 7 | 11 | 11 | 2 | 2 |   current_stride -> 2
+//   |     /      |     /      |     /      |      /
+//   |     /      |     /      |     /      |      /
+//   |     /      |     /      |     /      |      /
+// | 20 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 14 | 3 | 9 | 7 | 13 | 11 | 2 | 2 |   current_stride -> 4
+//   |            /              |            /
+//   |            /              |            /
+//   |            /              |            /
+//   |            /              |            /
+//   |            /              |            /
+// | 30 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 |   current_stride -> 8
+//   |                           /
+//   |                           /
+//   |                           /
+//   |                           /
+//   |                           /
+//   |                           /
+//   |                           /
+//   |                           /
+//   |                           /
+// | 57 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 |   current_stride -> 16
+//
+// Threads access the shared index in the following pattern
+// Thread       | 0 | 1 | 2 | 3 | 4 |  5 |  6 |  7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |   current_stride -> 1
+// Shared Index | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | X | X |  X |  X |  X |  X |  X |  X |   index *= 1
+//
+// Thread       | 0 | 1 | 2 |  3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |   current_stride -> 2
+// Shared Index | 0 | 4 | 8 | 12 | X | X | X | X | X | X |  X |  X |  X |  X |  X |  X |   index *= 2
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |   current_stride -> 4
+// Shared Index | 0 | 8 | X | X | X | X | X | X | X | X |  X |  X |  X |  X |  X |  X |   index *= 4
+//
+// Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |   current_stride -> 8
+// Shared Index | 0 | X | X | X | X | X | X | X | X | X |  X |  X |  X |  X |  X |  X |   index *= 8
+
 void reduce_input(const int width_stride, const int shared_idx_offset) {
   // wait for all shared memory writes to finish
   memoryBarrierShared();
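
The comments added in this hunk describe a standard shared-memory tree reduction in which the number of active threads halves at every stride. The following minimal Python model is not part of the commit: threads stands in for gl_WorkGroupSize.x, factor for SHARED_MEMORY_FACTOR, and the access pattern (idx = 2 * stride * t) mirrors the "Shared Index" table above.

# Sketch only: models the reduction pattern documented in the shader comments.
def reduce_shared(shared, threads):
    """Pairwise tree reduction; shared[0] ends up holding the total sum."""
    stride = 1
    while stride < len(shared):
        active = len(shared) // (2 * stride)   # threads that still have work this iteration
        for t in range(active):
            idx = 2 * stride * t               # matches the "Shared Index" access pattern
            shared[idx] += shared[idx + stride]
        print(f"stride={stride:2d}  active threads={active:2d}  utilization={active / threads:.0%}")
        stride *= 2
    return shared[0]

threads = 32
for factor in (1, 2):                          # 1: utilization starts at 50%, 2: starts at 100%
    print(f"SHARED_MEMORY_FACTOR={factor}")
    total = reduce_shared(list(range(threads * factor)), threads)
    print(f"sum={total}\n")

With factor 1 the first pass already idles half the workgroup; with factor 2 every thread does useful work on the first pass, which is the utilization argument made in the new comments.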
@@ -70,10 +127,9 @@ void reduce_input(const int width_stride, const int shared_idx_offset) {
   }
 }
 
-void main() {
+void reduce_non_packed_dim() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
   const int width = int(sizes.x);
-
   ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
 
   // width batch read stride
@@ -85,148 +141,177 @@ void main() {
   // local memory index for this thread
   const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
 
-  // if packed dimension width
-  if (in_packed_dim != W_DIM) {
-    VEC4_T mean = VEC4_T(0);
-    VEC4_T var = VEC4_T(0);
-
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-
-        VEC4_T in_val = VEC4_T(0);
-        if (all(lessThan(in_pos, out_limits))) {
-          in_val = load_texel(t_in, in_pos);
-        }
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
-      }
+  VEC4_T mean = VEC4_T(0);
+  VEC4_T var = VEC4_T(0);
 
-      reduce_input(width_stride, shared_idx_offset);
-      mean += shared_input[offset_pos_index(shared_idx_offset)];
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
+      VEC4_T in_val = VEC4_T(0);
+      if (all(lessThan(in_pos, out_limits))) {
+        in_val = load_texel(t_in, in_pos);
+      }
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
     }
 
-    mean /= width;
+    reduce_input(width_stride, shared_idx_offset);
+    mean += shared_input[offset_pos_index(shared_idx_offset)];
+  }
+
+  mean /= width;
 
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+  memoryBarrierShared();
+  barrier();
 
-        VEC4_T in_val = mean;
-        if (all(lessThan(in_pos, out_limits))) {
-          in_val = load_texel(t_in, in_pos);
-        }
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
 
-        const VEC4_T delta = in_val - mean;
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+      VEC4_T in_val = mean;
+      if (all(lessThan(in_pos, out_limits))) {
+        in_val = load_texel(t_in, in_pos);
       }
 
-      reduce_input(width_stride, shared_idx_offset);
-      var += shared_input[offset_pos_index(shared_idx_offset)];
+      const VEC4_T delta = in_val - mean;
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
     }
 
-    var /= width;
+    reduce_input(width_stride, shared_idx_offset);
+    var += shared_input[offset_pos_index(shared_idx_offset)];
+  }
 
-    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-    VEC4_T offset = -rstd * mean;
+  var /= width;
 
-    VEC4_T v = load_texel(t_in, lpos);
-    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
-    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
-    VEC4_T outtex = (v * rstd + offset) * weight + bias;
-    if (all(lessThan(lpos, out_limits))) {
-      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
-    }
+  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+  VEC4_T offset = -rstd * mean;
 
-    if (gl_GlobalInvocationID.x == 0) {
-      write_texel(t_mean, lpos, mean);
-      write_texel(t_rstd, lpos, rstd);
-    }
-  } else {
-    const int last_packed_width_index = divup4(width) - 1;
-    T mean = T(0);
-    T var = T(0);
-    const int remain = width & 3;
-
-    const int in_pos_x_limit = out_limits[in_axis_map.x];
-
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-        in_pos[in_axis_map.x] = in_pos_x;
-
-        VEC4_T in_val = VEC4_T(0);
-        if (in_pos_x < in_pos_x_limit) {
-          in_val = load_texel(t_in, in_pos);
-        }
-
-        if (in_pos_x == last_packed_width_index && remain != 0) {
-          const int remain_inv = 4 - remain;
-          in_val.y = mix(in_val.y, T(0), remain_inv > 2);
-          in_val.z = mix(in_val.z, T(0), remain_inv > 1);
-          in_val.w = mix(in_val.w, T(0), remain_inv > 0);
-        }
-
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+  VEC4_T v = load_texel(t_in, lpos);
+  VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+  VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+  VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+  if (all(lessThan(lpos, out_limits))) {
+    write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+  }
+
+  if (gl_GlobalInvocationID.x == 0) {
+    write_texel(t_mean, lpos, mean);
+    write_texel(t_rstd, lpos, rstd);
+  }
+}
+
+void reduce_packed_dim() {
+  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+  const int width = int(sizes.x);
+  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+  // width batch read stride
+  const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+  // local memory starting offset for this thread
+  const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+  // local memory index for this thread
+  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+  const int last_packed_width_index = divup4(width) - 1;
+  T mean = T(0);
+  T var = T(0);
+  const int remain = width & 3;
+
+  const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+      in_pos[in_axis_map.x] = in_pos_x;
+
+      VEC4_T in_val = VEC4_T(0);
+      if (in_pos_x < in_pos_x_limit) {
+        in_val = load_texel(t_in, in_pos);
       }
 
-      reduce_input(width_stride, shared_idx_offset);
-      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-      mean += val.x + val.y + val.z + val.w;
+      if (in_pos_x == last_packed_width_index && remain != 0) {
+        const int remain_inv = 4 - remain;
+        in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+        in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+        in_val.w = mix(in_val.w, T(0), remain_inv > 0);
+      }
+
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
     }
 
-    mean /= width;
-
-    // Loop over the width in stride increments
-    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
-      // Read input in shared memory
-      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
-        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
-        in_pos[in_axis_map.x] = in_pos_x;
-
-        VEC4_T in_val = VEC4_T(mean);
-        if (in_pos_x < in_pos_x_limit) {
-          in_val = load_texel(t_in, in_pos);
-        }
-
-        if (in_pos_x == last_packed_width_index && remain != 0) {
-          const int remain_inv = 4 - remain;
-          in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
-          in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
-          in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
-        }
-
-        const VEC4_T delta = in_val - mean;
-        const VEC4_T delta2 = delta * delta;
-        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
+    reduce_input(width_stride, shared_idx_offset);
+    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+    mean += val.x + val.y + val.z + val.w;
+  }
+
+  mean /= width;
+
+  memoryBarrierShared();
+  barrier();
+
+  // Loop over the width in stride increments
+  for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+    // Read input in shared memory
+    for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+      const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+      in_pos[in_axis_map.x] = in_pos_x;
+
+      VEC4_T in_val = VEC4_T(mean);
+      if (in_pos_x < in_pos_x_limit) {
+        in_val = load_texel(t_in, in_pos);
+      }
+
+      if (in_pos_x == last_packed_width_index && remain != 0) {
+        const int remain_inv = 4 - remain;
+        in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
+        in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
+        in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
       }
 
-      reduce_input(width_stride, shared_idx_offset);
-      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-      var += val.x + val.y + val.z + val.w;
+      const VEC4_T delta = in_val - mean;
+      const VEC4_T delta2 = delta * delta;
+      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
     }
 
-    var /= width;
+    reduce_input(width_stride, shared_idx_offset);
+    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+    var += val.x + val.y + val.z + val.w;
+  }
+
+  var /= width;
 
-    T rstd = pow(var + epsilon, T(-0.5));
-    T offset = -rstd * mean;
+  T rstd = pow(var + epsilon, T(-0.5));
+  T offset = -rstd * mean;
 
-    VEC4_T v = load_texel(t_in, lpos);
-    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
-    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
-    VEC4_T outtex = (v * rstd + offset) * weight + bias;
-    if (all(lessThan(lpos, out_limits))) {
-      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
-    }
+  VEC4_T v = load_texel(t_in, lpos);
+  VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+  VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+  VEC4_T outtex = (v * rstd + offset) * weight + bias;
 
-    if (gl_GlobalInvocationID.x == 0) {
-      write_texel(t_mean, lpos, VEC4_T(mean));
-      write_texel(t_rstd, lpos, VEC4_T(rstd));
-    }
+  if (all(lessThan(lpos, out_limits))) {
+    write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+  }
+
+  if (gl_GlobalInvocationID.x == 0) {
+    write_texel(t_mean, lpos, VEC4_T(mean));
+    write_texel(t_rstd, lpos, VEC4_T(rstd));
+  }
+}
+
+void main() {
+  // if packed dimension width
+  if (in_packed_dim != W_DIM) {
+    reduce_non_packed_dim();
+  } else {
+    reduce_packed_dim();
   }
 }
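
For reference, the per-row math that both reduce_non_packed_dim and reduce_packed_dim implement is ordinary layer norm with the reciprocal standard deviation folded into an additive offset, so each output texel is a single multiply-add followed by the affine weight and bias. A small plain-Python sketch (illustrative only, not part of the commit; row, weight, bias, and eps are placeholder names, with eps playing the role of the shader's epsilon):

# Sketch only: scalar reference for the texel math in the shader.
def native_layer_norm_row(row, weight, bias, eps=1e-5):
    width = len(row)
    mean = sum(row) / width
    var = sum((x - mean) ** 2 for x in row) / width         # biased variance, matching var /= width
    rstd = (var + eps) ** -0.5                               # pow(var + epsilon, -0.5)
    offset = -rstd * mean                                    # folded so each output is v * rstd + offset
    out = [(x * rstd + offset) * w + b for x, w, b in zip(row, weight, bias)]
    return out, mean, rstd

out, mean, rstd = native_layer_norm_row([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
print(out, mean, rstd)

The packed-dimension path reaches the same result by zero-filling the unused lanes of the last texel (the remain / remain_inv masking) before summing for the mean, and by substituting mean.x into those lanes when accumulating the variance, so dividing by width stays correct even when width is not a multiple of 4.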
