@@ -43,106 +43,190 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

- void main() {
-   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+ #define SHARED_MEMORY_FACTOR 2
+ #define MAX_WORKGROUP_SIZE 64
+
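+ // maps a logical index to a padded slot, skipping one slot after every four (likely to avoid shared memory bank conflicts)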
+ #define offset_pos_index(index) ((index) + ((index) >> 2))
+
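+ // workgroup-shared scratch: one padded row of width_stride partial values per (y, z) thread pair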
+ shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
+
+ // function to reduce input data in the workgroup's x dimension
+ void reduce_input(const int width_stride, const int shared_idx_offset) {
+   // wait for all shared memory writes to finish
+   memoryBarrierShared();
+   barrier();
+
+   // loop log2(width_stride) times
+   for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
+     // if the index at this thread is within the width stride
+     if (index < width_stride) {
+       const int local_shared_idx = shared_idx_offset + index;
+       // add the value at the current stride to this thread's value
+       shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
+     }

-   if (any(greaterThanEqual(lpos, out_limits))) {
-     return;
+     memoryBarrierShared();
+     barrier();
  }
+ }

+ void main() {
+   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  const int width = int(sizes.x);

+   ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+   // width batch read stride
+   const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+   // local memory starting offset for this thread
+   const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+   // local memory index for this thread
+   const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+   // case where the packed dimension is not the width dim
  if (in_packed_dim != W_DIM) {
    VEC4_T mean = VEC4_T(0);
-     VEC4_T delta = VEC4_T(0);
-     VEC4_T delta2 = VEC4_T(0);
-     VEC4_T M2 = VEC4_T(0);
-
-     // Use Welford's online algorithm to compute mean and variance in one pass
-     // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-     ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-     for (int w = 0; w < width; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       delta = v - mean;
-       mean += delta / (w + 1);
-       delta2 = v - mean;
-       M2 += delta * delta2;
+     VEC4_T var = VEC4_T(0);
+
+     // Loop over the width in stride increments
+     for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+       // Read input into shared memory
+       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+         in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
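+         // lanes that fall outside the input contribute zero to the sum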
+         VEC4_T in_val = VEC4_T(0);
+         if (all(lessThan(in_pos, out_limits))) {
+           in_val = load_texel(t_in, in_pos);
+         }
+         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+       }
+
+       reduce_input(width_stride, shared_idx_offset);
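+       // the reduced row total now sits at the start of this row's shared segment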
+       mean += shared_input[offset_pos_index(shared_idx_offset)];
+     }
+
+     mean /= width;
+
+     // Loop over the width in stride increments
+     for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+       // Read input into shared memory
+       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+         in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
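+         // out-of-range lanes read the mean so their delta is zero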
+         VEC4_T in_val = mean;
+         if (all(lessThan(in_pos, out_limits))) {
+           in_val = load_texel(t_in, in_pos);
+         }
+
+         const VEC4_T delta = in_val - mean;
+         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+       }
+
+       reduce_input(width_stride, shared_idx_offset);
+       var += shared_input[offset_pos_index(shared_idx_offset)];
    }

-     VEC4_T var = M2 / width;
+     var /= width;
+
    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
    VEC4_T offset = -rstd * mean;

-     for (int w = 0; w < width; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       // broadcasting
-       VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
-       VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
-       VEC4_T outtex = (v * rstd + offset) * weight + bias;
-       write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+     VEC4_T v = load_texel(t_in, lpos);
+     VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+     VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+     VEC4_T outtex = (v * rstd + offset) * weight + bias;
+     if (all(lessThan(lpos, out_limits))) {
+       write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
    }

-     write_texel(t_mean, lpos, mean);
-     write_texel(t_rstd, lpos, rstd);
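+     // only the first thread along x writes the per-row statistics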
+     if (gl_GlobalInvocationID.x == 0) {
+       write_texel(t_mean, lpos, mean);
+       write_texel(t_rstd, lpos, rstd);
+     }
  } else {
-     const int packed_width = divup4(width);
-
+     const int last_packed_width_index = divup4(width) - 1;
    T mean = T(0);
-     T delta = T(0);
-     T delta2 = T(0);
-     T M2 = T(0);
-     // Use Welford's online algorithm to compute mean and variance in one pass
-     // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-     ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-     T width_counter = T(1);
-
-     const bool has_unaligned_width = (width & 0x3) != 0;
-     const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
-
-     // iterate through texels that are fully packed ie. has 4 components
-     for (int w = 0; w < fully_packed_4_comp_count; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       for (int i = 0; i < 4; i++) {
-         delta = v[i] - mean;
-         mean += delta / width_counter;
-         delta2 = v[i] - mean;
-         M2 += delta * delta2;
-         width_counter++;
+     T var = T(0);
+     const int remain = width & 3;
+
+     const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+     // Loop over the width in stride increments
+     for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+       // Read input into shared memory
+       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+         const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+         in_pos[in_axis_map.x] = in_pos_x;
+
+         VEC4_T in_val = VEC4_T(0);
+         if (in_pos_x < in_pos_x_limit) {
+           in_val = load_texel(t_in, in_pos);
+         }
+
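+         // zero the padding lanes of the last, partially filled texel so they do not skew the sum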
+         if (in_pos_x == last_packed_width_index && remain != 0) {
+           const int remain_inv = 4 - remain;
+           in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+           in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+           in_val.w = mix(in_val.w, T(0), remain_inv > 0);
+         }
+
+         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
      }
+
+       reduce_input(width_stride, shared_idx_offset);
+       const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+       mean += val.x + val.y + val.z + val.w;
    }

-     // handle last texel if its not 4 aligned
-     if (has_unaligned_width) {
-       in_pos[in_axis_map.x] = fully_packed_4_comp_count;
-       const int remaining_width = width & 0x3;
-
-       VEC4_T v = load_texel(t_in, in_pos);
-       for (int i = 0; i < remaining_width; i++) {
-         delta = v[i] - mean;
-         mean += delta / width_counter;
-         delta2 = v[i] - mean;
-         M2 += delta * delta2;
-         width_counter++;
+     mean /= width;
+
+     // Loop over the width in stride increments
+     for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+       // Read input into shared memory
+       for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+         const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+         in_pos[in_axis_map.x] = in_pos_x;
+
+         VEC4_T in_val = VEC4_T(mean);
+         if (in_pos_x < in_pos_x_limit) {
+           in_val = load_texel(t_in, in_pos);
+         }
+
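+         // padding lanes take the mean so their squared delta is zero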
+         if (in_pos_x == last_packed_width_index && remain != 0) {
+           const int remain_inv = 4 - remain;
+           in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
+           in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
+           in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
+         }
+
+         const VEC4_T delta = in_val - mean;
+         const VEC4_T delta2 = delta * delta;
+         shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
      }
+
+       reduce_input(width_stride, shared_idx_offset);
+       const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+       var += val.x + val.y + val.z + val.w;
    }

-     T var = M2 / (width_counter - 1);
-     T rstd = inversesqrt(var + epsilon);
+     var /= width;
+
+     T rstd = pow(var + epsilon, T(-0.5));
    T offset = -rstd * mean;

-     for (int w = 0; w < packed_width; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
-       VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
-       VEC4_T outtex = (v * rstd + offset) * weight + bias;
-       write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+     VEC4_T v = load_texel(t_in, lpos);
+     VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+     VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+     VEC4_T outtex = (v * rstd + offset) * weight + bias;
+     if (all(lessThan(lpos, out_limits))) {
+       write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
    }

-     write_texel(t_mean, lpos, VEC4_T(mean));
-     write_texel(t_rstd, lpos, VEC4_T(rstd));
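+     // again, only the first thread along x writes mean and rstd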
+     if (gl_GlobalInvocationID.x == 0) {
+       write_texel(t_mean, lpos, VEC4_T(mean));
+       write_texel(t_rstd, lpos, VEC4_T(rstd));
+     }
  }
}
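The heart of this change is the workgroup tree reduction in reduce_input. For reference, below is a minimal, self-contained sketch of the same idea on a flat float buffer, using the simpler sequential-addressing variant rather than the interleaved, bank-conflict-padded indexing above. The buffer bindings, the fixed workgroup size of 64, the `partial` array, and the assumption that the input length is a multiple of 64 are all illustrative, not part of this PR.

#version 450
layout(local_size_x = 64) in;

layout(std430, binding = 0) readonly buffer InBuf { float in_data[]; };
layout(std430, binding = 1) writeonly buffer OutBuf { float out_data[]; };

shared float partial[64];

void main() {
  const uint lid = gl_LocalInvocationID.x;

  // each invocation stages one element into shared memory
  partial[lid] = in_data[gl_GlobalInvocationID.x];
  memoryBarrierShared();
  barrier();

  // tree reduction: halve the number of active threads each step;
  // after the loop, partial[0] holds this workgroup's sum
  for (uint stride = 32u; stride > 0u; stride >>= 1) {
    if (lid < stride) {
      partial[lid] += partial[lid + stride];
    }
    memoryBarrierShared();
    barrier();
  }

  // one invocation per workgroup writes the result
  if (lid == 0u) {
    out_data[gl_WorkGroupID.x] = partial[0];
  }
}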