@@ -43,106 +43,190 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
4343const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
4444const lowp int out_packed_dim = unhash_packed_dim(out_layout);
4545
46- void main() {
47- const ivec3 lpos = ivec3 (gl_GlobalInvocationID);
// Number of texels along the width axis that each thread stages into shared
// memory per outer-loop iteration before the workgroup-level reduction.
#define SHARED_MEMORY_FACTOR 2
#define MAX_WORKGROUP_SIZE 64

// Pads a shared-memory slot index by index/4. Presumably staggers accesses to
// avoid shared-memory bank conflicts -- NOTE(review): confirm intent; the
// padding factor must match the shared_input sizing below.
#define offset_pos_index(index) ((index) + ((index) >> 2))

// Workgroup scratch buffer for the tree reduction, sized with the
// offset_pos_index padding applied to the maximum slot count.
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];

// Tree-reduces the per-thread partial values held in shared_input along the
// workgroup's x dimension. On return, the reduced sum for this thread row is
// at shared_input[offset_pos_index(shared_idx_offset)].
//
//   width_stride      - number of slots in this row's segment (assumed a
//                       power of two -- it is gl_WorkGroupSize.x *
//                       SHARED_MEMORY_FACTOR at every call site)
//   shared_idx_offset - first slot of this row's segment in shared_input
void reduce_input(const int width_stride, const int shared_idx_offset) {
  // Wait for all shared memory writes issued before the reduction to land.
  memoryBarrierShared();
  barrier();

  // log2(width_stride) passes: each pass folds the element current_stride
  // slots away into this thread's slot. index doubles alongside the stride so
  // active threads stay densely packed at the front of the segment.
  for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1);
       current_stride < width_stride;
       current_stride *= 2, index <<= 1) {
    // Only threads whose index still falls inside the segment participate.
    if (index < width_stride) {
      const int local_shared_idx = shared_idx_offset + index;
      // Accumulate the partner element into this thread's slot.
      shared_input[offset_pos_index(local_shared_idx)] +=
          shared_input[offset_pos_index(local_shared_idx + current_stride)];
    }

    // Publish this pass's writes before the next pass reads them.
    memoryBarrierShared();
    barrier();
  }
}
5272
/*
 * Layer normalization over the width dimension, computed cooperatively by the
 * workgroup via the shared-memory reduction in reduce_input().
 *
 * Two paths, selected by the input's packed dimension:
 *  - in_packed_dim != W_DIM: each texel holds 4 independent channels, so the
 *    mean/variance accumulators are per-component (VEC4_T).
 *  - in_packed_dim == W_DIM: 4 consecutive width elements share one texel, so
 *    the 4 lanes are summed into scalar (T) accumulators and the trailing
 *    partial texel is masked when width is not a multiple of 4.
 *
 * Both paths make two passes over the width: first a sum for the mean, then a
 * sum of squared deltas for the variance (replacing the previous single-pass
 * Welford implementation with a parallel two-pass one).
 */
void main() {
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  const int width = int(sizes.x);

  // Texture-space position for this logical position; lpos_to_pos/in_axis_map
  // handle the layout permutation (declared elsewhere in this shader).
  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);

  // Number of width elements consumed per outer-loop iteration (the size of
  // one shared-memory reduction segment).
  const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;

  // Start of this thread row's (y, z) segment in shared memory.
  const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);

  // This thread's first slot within its segment.
  const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);

  if (in_packed_dim != W_DIM) {
    // ---- Per-component path: each texel lane is an independent channel ----
    VEC4_T mean = VEC4_T(0);
    VEC4_T var = VEC4_T(0);

    // Pass 1: accumulate the sum of inputs for the mean.
    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
      // Stage SHARED_MEMORY_FACTOR texels per thread into shared memory,
      // zero-filling out-of-bounds positions so they don't affect the sum.
      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);

        VEC4_T in_val = VEC4_T(0);
        if (all(lessThan(in_pos, out_limits))) {
          in_val = load_texel(t_in, in_pos);
        }
        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
      }

      reduce_input(width_stride, shared_idx_offset);
      mean += shared_input[offset_pos_index(shared_idx_offset)];
    }

    mean /= width;

    // Pass 2: accumulate the sum of squared deltas for the variance.
    for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
        in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);

        // Out-of-bounds lanes read back the mean so their delta is zero.
        VEC4_T in_val = mean;
        if (all(lessThan(in_pos, out_limits))) {
          in_val = load_texel(t_in, in_pos);
        }

        const VEC4_T delta = in_val - mean;
        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
      }

      reduce_input(width_stride, shared_idx_offset);
      var += shared_input[offset_pos_index(shared_idx_offset)];
    }

    var /= width;

    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
    VEC4_T offset = -rstd * mean;

    // Normalize and write this thread's own texel; weight/bias are broadcast
    // from their x component.
    VEC4_T v = load_texel(t_in, lpos);
    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
    VEC4_T outtex = (v * rstd + offset) * weight + bias;
    if (all(lessThan(lpos, out_limits))) {
      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
    }

    // Only one thread per row persists the statistics.
    if (gl_GlobalInvocationID.x == 0) {
      write_texel(t_mean, lpos, mean);
      write_texel(t_rstd, lpos, rstd);
    }
  } else {
    // ---- Width-packed path: 4 width elements per texel, scalar stats ----
    const int last_packed_width_index = divup4(width) - 1;
    T mean = T(0);
    T var = T(0);
    // Number of valid elements in the final (possibly partial) texel; 0 means
    // the last texel is fully populated.
    const int remain = width & 3;

    const int in_pos_x_limit = out_limits[in_axis_map.x];

    // Pass 1: sum of inputs. The loop walks packed-texel indices.
    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
        in_pos[in_axis_map.x] = in_pos_x;

        VEC4_T in_val = VEC4_T(0);
        if (in_pos_x < in_pos_x_limit) {
          in_val = load_texel(t_in, in_pos);
        }

        // Zero the padding lanes of the trailing partial texel so they do not
        // contribute to the sum.
        if (in_pos_x == last_packed_width_index && remain != 0) {
          const int remain_inv = 4 - remain;
          in_val.y = mix(in_val.y, T(0), remain_inv > 2);
          in_val.z = mix(in_val.z, T(0), remain_inv > 1);
          in_val.w = mix(in_val.w, T(0), remain_inv > 0);
        }

        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
      }

      reduce_input(width_stride, shared_idx_offset);
      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
      // Collapse the 4 packed lanes into the scalar accumulator.
      mean += val.x + val.y + val.z + val.w;
    }

    mean /= width;

    // Pass 2: sum of squared deltas. Padding lanes are set to the mean so
    // their delta is zero.
    for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
      for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
        const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
        in_pos[in_axis_map.x] = in_pos_x;

        VEC4_T in_val = VEC4_T(mean);
        if (in_pos_x < in_pos_x_limit) {
          in_val = load_texel(t_in, in_pos);
        }

        // NOTE(review): mean is declared as scalar T in this branch, yet the
        // transcription below swizzles mean.x -- verify against the original
        // source that this compiles (plain `mean` may be intended).
        if (in_pos_x == last_packed_width_index && remain != 0) {
          const int remain_inv = 4 - remain;
          in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
          in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
          in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
        }

        const VEC4_T delta = in_val - mean;
        const VEC4_T delta2 = delta * delta;
        shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
      }

      reduce_input(width_stride, shared_idx_offset);
      const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
      var += val.x + val.y + val.z + val.w;
    }

    var /= width;

    T rstd = pow(var + epsilon, T(-0.5));
    T offset = -rstd * mean;

    // Normalize and write this thread's own texel (no broadcast swizzle here:
    // weight/bias are themselves width-packed in this layout).
    VEC4_T v = load_texel(t_in, lpos);
    VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
    VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
    VEC4_T outtex = (v * rstd + offset) * weight + bias;
    if (all(lessThan(lpos, out_limits))) {
      write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
    }

    // Only one thread per row persists the (broadcast) statistics.
    if (gl_GlobalInvocationID.x == 0) {
      write_texel(t_mean, lpos, VEC4_T(mean));
      write_texel(t_rstd, lpos, VEC4_T(rstd));
    }
  }
}
0 commit comments