Skip to content

Commit 59013cd

Browse files
bluss and jturner314
committed
move_into: Implement inner-dimension optimization in move-into
If the innermost dimension is contiguous, we can skip it in one go, and save some work in the dropping loop in move_into. Co-authored-by: Jim Turner <[email protected]>
1 parent b103515 commit 59013cd

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

src/impl_owned_array.rs

+21-6
Original file line number | Diff line number | Diff line change
@@ -255,10 +255,6 @@ impl<A, D> Array<A, D>
255255
debug_assert!(data_ptr <= array_memory_head_ptr);
256256
debug_assert!(array_memory_head_ptr <= data_end_ptr);
257257

258-
// iter is a raw pointer iterator traversing self_ in its standard order
259-
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
260-
let mut dropped_elements = 0;
261-
262258
// The idea is simply this: the iterator will yield the elements of self_ in
263259
// increasing address order.
264260
//
@@ -267,6 +263,25 @@ impl<A, D> Array<A, D>
267263
//
268264
// We have to drop elements in the range from `data_ptr` until (not including)
269265
// `data_end_ptr`, except those that are produced by `iter`.
266+
267+
// As an optimization, the innermost axis is removed if it has stride 1, because
268+
// we then have a long stretch of contiguous elements we can skip as one.
269+
let inner_lane_len;
270+
if self_.ndim() > 1 && self_.strides.last_elem() == 1 {
271+
self_.dim.slice_mut().rotate_right(1);
272+
self_.strides.slice_mut().rotate_right(1);
273+
inner_lane_len = self_.dim[0];
274+
self_.dim[0] = 1;
275+
self_.strides[0] = 1;
276+
} else {
277+
inner_lane_len = 1;
278+
}
279+
280+
// iter is a raw pointer iterator that now traverses the array in memory order,
281+
// following the sorted axes.
282+
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
283+
let mut dropped_elements = 0;
284+
270285
let mut last_ptr = data_ptr;
271286

272287
while let Some(elem_ptr) = iter.next() {
@@ -278,8 +293,8 @@ impl<A, D> Array<A, D>
278293
last_ptr = last_ptr.add(1);
279294
dropped_elements += 1;
280295
}
281-
// Next interval will continue one past the current element
282-
last_ptr = elem_ptr.add(1);
296+
// Next interval will continue one past the current lane
297+
last_ptr = elem_ptr.add(inner_lane_len);
283298
}
284299

285300
while last_ptr < data_end_ptr {

0 commit comments

Comments
 (0)