
Commit 9a3b555

SS-JIA and trivedivivek authored
[ET-VK] Manual sync to fbsource (#10238)
Contains commits squashed from the #10117 PR stack.

Differential Revision: [D72866962](https://our.internmc.facebook.com/intern/diff/D72866962/)
Differential Revision: [D72862490](https://our.internmc.facebook.com/intern/diff/D72862490/)
Differential Revision: [D72581293](https://our.internmc.facebook.com/intern/diff/D72581293/)
Differential Revision: [D72430290](https://our.internmc.facebook.com/intern/diff/D72430290/)

@diff-train-skip-merge

Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 9af1043 commit 9a3b555

File tree

6 files changed (+228, -280 lines)


backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

+159 -75
```diff
@@ -43,106 +43,190 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
 const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
 const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 
-void main() {
-  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+#define SHARED_MEMORY_FACTOR 2
+#define MAX_WORKGROUP_SIZE 64
+
+#define offset_pos_index(index) ((index) + ((index) >> 2))
+
+shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
+
+// function to reduce input data in workgroup's x dimension
+void reduce_input(const int width_stride, const int shared_idx_offset) {
+  // wait for all shared memory writes to finish
+  memoryBarrierShared();
+  barrier();
+
+  // loop log(width_stride) times
+  for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
+    // if the index at this thread is within the width stride
+    if (index < width_stride) {
+      const int local_shared_idx = shared_idx_offset + index;
+      // add the value at current stride to this thread's value
+      shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
+    }
 
-  if (any(greaterThanEqual(lpos, out_limits))) {
-    return;
+    memoryBarrierShared();
+    barrier();
   }
+}
 
+void main() {
+  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
   const int width = int(sizes.x);
 
+  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+  // width batch read stride
+  const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+  // local memory starting offset for this thread
+  const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+  // local memory index for this thread
+  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+  // if packed dimension width
   if (in_packed_dim != W_DIM) {
     VEC4_T mean = VEC4_T(0);
-    VEC4_T delta = VEC4_T(0);
-    VEC4_T delta2 = VEC4_T(0);
-    VEC4_T M2 = VEC4_T(0);
-
-    // Use Welford's online algorithm to compute mean and variance in one pass
-    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-    for (int w = 0; w < width; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      delta = v - mean;
-      mean += delta / (w + 1);
-      delta2 = v - mean;
-      M2 += delta * delta2;
+    VEC4_T var = VEC4_T(0);
+
+    // Loop over the width in stride increments
+    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+      // Read input in shared memory
+      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
+        VEC4_T in_val = VEC4_T(0);
+        if (all(lessThan(in_pos, out_limits))) {
+          in_val = load_texel(t_in, in_pos);
+        }
+        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+      }
+
+      reduce_input(width_stride, shared_idx_offset);
+      mean += shared_input[offset_pos_index(shared_idx_offset)];
+    }
+
+    mean /= width;
+
+    // Loop over the width in stride increments
+    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+      // Read input in shared memory
+      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
+        VEC4_T in_val = mean;
+        if (all(lessThan(in_pos, out_limits))) {
+          in_val = load_texel(t_in, in_pos);
+        }
+
+        const VEC4_T delta = in_val - mean;
+        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+      }
+
+      reduce_input(width_stride, shared_idx_offset);
+      var += shared_input[offset_pos_index(shared_idx_offset)];
     }
 
-    VEC4_T var = M2 / width;
+    var /= width;
+
     VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
     VEC4_T offset = -rstd * mean;
 
-    for (int w = 0; w < width; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      // broadcasting
-      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
-      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
-      VEC4_T outtex = (v * rstd + offset) * weight + bias;
-      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    VEC4_T v = load_texel(t_in, lpos);
+    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+    VEC4_T outtex = (v * rstd + offset) * weight + bias;
+    if (all(lessThan(lpos, out_limits))) {
+      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
     }
 
-    write_texel(t_mean, lpos, mean);
-    write_texel(t_rstd, lpos, rstd);
+    if (gl_GlobalInvocationID.x == 0) {
+      write_texel(t_mean, lpos, mean);
+      write_texel(t_rstd, lpos, rstd);
+    }
   } else {
-    const int packed_width = divup4(width);
-
+    const int last_packed_width_index = divup4(width) - 1;
     T mean = T(0);
-    T delta = T(0);
-    T delta2 = T(0);
-    T M2 = T(0);
-    // Use Welford's online algorithm to compute mean and variance in one pass
-    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-    T width_counter = T(1);
-
-    const bool has_unaligned_width = (width & 0x3) != 0;
-    const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
-
-    // iterate through texels that are fully packed ie. has 4 components
-    for (int w = 0; w < fully_packed_4_comp_count; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      for (int i=0; i<4; i++) {
-        delta = v[i] - mean;
-        mean += delta / width_counter;
-        delta2 = v[i] - mean;
-        M2 += delta * delta2;
-        width_counter++;
+    T var = T(0);
+    const int remain = width & 3;
+
+    const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+    // Loop over the width in stride increments
+    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+      // Read input in shared memory
+      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+        in_pos[in_axis_map.x] = in_pos_x;
+
+        VEC4_T in_val = VEC4_T(0);
+        if (in_pos_x < in_pos_x_limit) {
+          in_val = load_texel(t_in, in_pos);
+        }
+
+        if (in_pos_x == last_packed_width_index && remain != 0) {
+          const int remain_inv = 4 - remain;
+          in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+          in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+          in_val.w = mix(in_val.w, T(0), remain_inv > 0);
+        }
+
+        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
       }
+
+      reduce_input(width_stride, shared_idx_offset);
+      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+      mean += val.x + val.y + val.z + val.w;
     }
 
-    // handle last texel if its not 4 aligned
-    if (has_unaligned_width) {
-      in_pos[in_axis_map.x] = fully_packed_4_comp_count;
-      const int remaining_width = width & 0x3;
-
-      VEC4_T v = load_texel(t_in, in_pos);
-      for (int i=0; i<remaining_width; i++) {
-        delta = v[i] - mean;
-        mean += delta / width_counter;
-        delta2 = v[i] - mean;
-        M2 += delta * delta2;
-        width_counter++;
+    mean /= width;
+
+    // Loop over the width in stride increments
+    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+      // Read input in shared memory
+      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+        in_pos[in_axis_map.x] = in_pos_x;
+
+        VEC4_T in_val = VEC4_T(mean);
+        if (in_pos_x < in_pos_x_limit) {
+          in_val = load_texel(t_in, in_pos);
+        }
+
+        if (in_pos_x == last_packed_width_index && remain != 0) {
+          const int remain_inv = 4 - remain;
+          in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
+          in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
+          in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
+        }
+
+        const VEC4_T delta = in_val - mean;
+        const VEC4_T delta2 = delta * delta;
+        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
       }
+
+      reduce_input(width_stride, shared_idx_offset);
+      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+      var += val.x + val.y + val.z + val.w;
     }
 
-    T var = M2 / (width_counter - 1);
-    T rstd = inversesqrt(var + epsilon);
+    var /= width;
+
+    T rstd = pow(var + epsilon, T(-0.5));
     T offset = -rstd * mean;
 
-    for (int w = 0; w < packed_width; ++w) {
-      in_pos[in_axis_map.x] = w;
-      VEC4_T v = load_texel(t_in, in_pos);
-      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
-      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
-      VEC4_T outtex = (v * rstd + offset) * weight + bias;
-      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    VEC4_T v = load_texel(t_in, lpos);
+    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+    VEC4_T outtex = (v * rstd + offset) * weight + bias;
+    if (all(lessThan(lpos, out_limits))) {
+      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
     }
 
-    write_texel(t_mean, lpos, VEC4_T(mean));
-    write_texel(t_rstd, lpos, VEC4_T(rstd));
+    if (gl_GlobalInvocationID.x == 0) {
+      write_texel(t_mean, lpos, VEC4_T(mean));
+      write_texel(t_rstd, lpos, VEC4_T(rstd));
+    }
   }
 }
```
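For context on the diff above: the rewrite drops the per-thread Welford loop in favor of a two-pass, workgroup-cooperative reduction. Each thread stages `SHARED_MEMORY_FACTOR` texels into shared memory, `reduce_input` tree-reduces them across the workgroup's x dimension with a barrier between rounds, and `offset_pos_index` pads one extra slot per four entries, a common layout for staggering shared-memory bank accesses. Below is a minimal sketch in Python, not repository code, assuming a 4-wide workgroup and scalar elements in place of `VEC4_T` texels; the sequential loops model the post-barrier state of each round.

```python
SHARED_MEMORY_FACTOR = 2
WORKGROUP_SIZE_X = 4  # stand-in for gl_WorkGroupSize.x

def offset_pos_index(index):
    # Mirror of the GLSL macro: one pad slot per 4 entries.
    return index + (index >> 2)

def reduce_input(shared, width_stride):
    # Tree reduction: at stride s, thread t folds slot t*2s + s into slot
    # t*2s; after log2(width_stride) rounds the total sits in slot 0.
    stride = 1
    while stride < width_stride:
        for tid in range(WORKGROUP_SIZE_X):  # threads step in lockstep
            index = tid * (stride << 1)
            if index < width_stride:
                shared[offset_pos_index(index)] += shared[offset_pos_index(index + stride)]
        stride <<= 1

def mean_and_var(row):
    # Two passes over the row, width_stride elements at a time, as in the shader.
    width = len(row)
    width_stride = WORKGROUP_SIZE_X * SHARED_MEMORY_FACTOR
    shared = [0.0] * offset_pos_index(width_stride)

    mean = 0.0
    for width_offset in range(0, width, width_stride):
        for si in range(SHARED_MEMORY_FACTOR):
            for tid in range(WORKGROUP_SIZE_X):
                x = width_offset + tid + si * WORKGROUP_SIZE_X
                # Out-of-bounds lanes contribute 0 to the sum.
                shared[offset_pos_index(tid + si * WORKGROUP_SIZE_X)] = row[x] if x < width else 0.0
        reduce_input(shared, width_stride)
        mean += shared[offset_pos_index(0)]
    mean /= width

    var = 0.0
    for width_offset in range(0, width, width_stride):
        for si in range(SHARED_MEMORY_FACTOR):
            for tid in range(WORKGROUP_SIZE_X):
                x = width_offset + tid + si * WORKGROUP_SIZE_X
                # Out-of-bounds lanes are padded with mean, so their delta is 0.
                v = row[x] if x < width else mean
                shared[offset_pos_index(tid + si * WORKGROUP_SIZE_X)] = (v - mean) ** 2
        reduce_input(shared, width_stride)
        var += shared[offset_pos_index(0)]
    return mean, var / width

row = [float(i) for i in range(10)]  # width not a multiple of the stride
mean, var = mean_and_var(row)
assert abs(mean - sum(row) / len(row)) < 1e-9
assert abs(var - sum((v - mean) ** 2 for v in row) / len(row)) < 1e-9
```

Padding out-of-range reads with 0 in the mean pass and with the mean in the variance pass is what keeps both sums exact when the width is not a multiple of the stride; this mirrors the shader's `VEC4_T(0)` and `mean` fallbacks.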

backends/vulkan/runtime/graph/ops/glsl/permute.glsl

+9 -5
```diff
@@ -31,6 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 layout(constant_id = 3) const int packed_dim = C_DIM;
 
+#extension GL_EXT_control_flow_attributes : require
+
 void main() {
   ivec3 pos = ivec3(gl_GlobalInvocationID);
 
@@ -54,11 +56,16 @@ void main() {
   in_bchw_pos[out_ndims[2]] = pos.y;
   in_bchw_pos[out_ndims[3]] = pos.x;
 
-  for (int j = 0; j < 4; ++j) {
+  const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]];
+
+  [[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) {
     // terminate the loop if trying to access input texture out of bounds
-    if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
+    if (bchw_index >= in_packed_dim_size) {
       break;
     }
+    // go to position in the input, that is mapped to the packed dim in the output
+    in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index;
+
     ivec3 fetch_pos;
 
     fetch_pos.xy = in_bchw_pos.wz;
@@ -74,9 +81,6 @@ void main() {
     // fetch input texel
     VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos));
     outval[j] = inval[in_packed_dim_lane_index];
-
-    // go to next position in the input, that is mapped to the packed dim in the output
-    in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++;
   }
 
   pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);
```
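The permute change is smaller: the per-texel gather loop now computes the input's size along the dimension mapped to the output's packed dim once, walks an explicit `bchw_index`, and breaks on that single bound instead of re-testing all four coordinates every iteration; `GL_EXT_control_flow_attributes` is pulled in so the loop can carry the `[[unroll]]` hint. A rough Python model of the new control flow (illustrative names and helper, not repository code):

```python
def gather_packed_lanes(in_bchw_pos, packed_axis, in_packed_dim_size, fetch):
    # Collect up to 4 lanes for one output texel by stepping bchw_index
    # along the input dim mapped to the output's packed dim, bounds-checked
    # against that dim's size only.
    outval = [0.0] * 4
    pos = list(in_bchw_pos)
    bchw_index = in_bchw_pos[packed_axis]
    for j in range(4):  # the shader unrolls this loop
        if bchw_index >= in_packed_dim_size:
            break  # past the end: remaining lanes stay 0
        pos[packed_axis] = bchw_index
        outval[j] = fetch(tuple(pos))
        bchw_index += 1
    return outval

# Example: a 3-channel input leaves the 4th lane as zero padding.
data = {(0, c, 0, 0): float(c + 1) for c in range(3)}
lanes = gather_packed_lanes((0, 0, 0, 0), packed_axis=1, in_packed_dim_size=3,
                            fetch=lambda p: data.get(p, 0.0))
assert lanes == [1.0, 2.0, 3.0, 0.0]
```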
