@@ -148,9 +148,18 @@ namespace gridtools {
148
148
: m_total_lengths{dims}, m_strides(strides) {
149
149
150
150
// We guess the padded lengths from the dimensions and the strides. Assume that the strides are sorted,
151
- // e.g., [256, 16, 1], and the dimensions are [5, 9, 9]. For the largest stride, we assume that padding
152
- // = dimension (e.g. in this example the i-padding is 5). For all others we can calculate the padding from
153
- // the strides (e.g. in this example, the j-padding is 256 / 16 = 16, and the k-padding is 16 / 1 = 1).
151
+ // e.g., [1, 16, 256], and the dimensions are [5, 9, 9]. For the largest stride (256), we assume that
152
+ // padding = dimension (e.g. in this example the k-padding is 9). For all others we can calculate the
153
+ // padding from the strides (e.g. in this example, the j-padding is 256 / 16 = 16, and the i-padding is 16 /
154
+ // 1 = 16). Note that there might be strides which are set to 0 (masked dimensions).
155
+ //
156
+ // We first create a sorted copy of the strides array. We then loop over the unsorted array and set the padded
157
+ // length for each entry as follows:
158
+ // - If the stride is masked, the padded length is 0.
159
+ // - If the stride is the maximum stride (i.e., 256 in the example above), the padding is derived from the
160
+ // dimension.
161
+ // - Otherwise, we find the stride s in the sorted array and we look for the next larger stride l in the
162
+ // sorted array. The padded length is then set to l / s. Note that strides might appear several times.
154
163
auto sorted_strides = strides;
155
164
for (uint_t i = 0 ; i < ndims; ++i)
156
165
for (uint_t j = i + 1 ; j < ndims; ++j)
@@ -166,7 +175,11 @@ namespace gridtools {
166
175
else if (strides[i] == 0 ) {
167
176
m_padded_lengths[i] = 0 ;
168
177
} else {
169
- for (int j = i; j < ndims; ++j)
178
+ int i_in_sorted_stride = 0 ;
179
+ for (; i_in_sorted_stride < ndims; ++i_in_sorted_stride)
180
+ if (strides[i] == sorted_strides[i_in_sorted_stride])
181
+ break ;
182
+ for (int j = i_in_sorted_stride; j < ndims; ++j)
170
183
if (strides[i] != sorted_strides[j]) {
171
184
m_padded_lengths[i] = sorted_strides[j] / strides[i];
172
185
break ;
0 commit comments