Skip to content

Commit a3bceb8

Browse files
committed
FEAT: Speed up applying the permutation in sort-axis example
Speed it up by avoiding bounds checking when looking up the pane to move in the source array. This works because for any given element pointer in the array we have the relationship: .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j) where + is pointer arithmetic on the element pointers.
1 parent f7a1277 commit a3bceb8

File tree

1 file changed

+70
-31
lines changed

1 file changed

+70
-31
lines changed

examples/sort-axis.rs

+70-31
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use ndarray::prelude::*;
22
use ndarray::{Data, RemoveAxis, Zip};
33

4+
use rawpointer::PointerExt;
5+
46
use std::cmp::Ordering;
57
use std::ptr::copy_nonoverlapping;
68

@@ -97,8 +99,8 @@ where
9799
where
98100
D: RemoveAxis,
99101
{
100-
let axis = axis;
101102
let axis_len = self.len_of(axis);
103+
let axis_stride = self.stride_of(axis);
102104
assert_eq!(axis_len, perm.indices.len());
103105
debug_assert!(perm.correct());
104106

@@ -112,26 +114,48 @@ where
112114
// logically move ownership of all elements from self into result
113115
// the result realizes this ownership at .assume_init() further down
114116
let mut moved_elements = 0;
117+
118+
// the permutation vector is used like this:
119+
//
120+
// index: 0 1 2 3 (index in result)
121+
// permut: 2 3 0 1 (index in the source)
122+
//
123+
// move source 2 -> result 0,
124+
// move source 3 -> result 1,
125+
// move source 0 -> result 2,
126+
// move source 1 -> result 3,
127+
// et.c.
128+
129+
let source_0 = self.raw_view().index_axis_move(axis, 0);
130+
115131
Zip::from(&perm.indices)
116132
.and(result.axis_iter_mut(axis))
117133
.for_each(|&perm_i, result_pane| {
118-
// possible improvement: use unchecked indexing for `index_axis`
134+
// Use a shortcut to avoid bounds checking in `index_axis` for the source.
135+
//
136+
// It works because for any given element pointer in the array we have the
137+
// relationship:
138+
//
139+
// .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j)
140+
//
141+
// where + is pointer arithmetic on the element pointers.
142+
//
143+
// Here source_0 and the offset is equivalent to self.index_axis(axis, perm_i)
119144
Zip::from(result_pane)
120-
.and(self.index_axis(axis, perm_i))
121-
.for_each(|to, from| {
145+
.and(source_0.clone())
146+
.for_each(|to, from_0| {
147+
let from = from_0.stride_offset(axis_stride, perm_i);
122148
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
123149
moved_elements += 1;
124150
});
125151
});
126152
debug_assert_eq!(result.len(), moved_elements);
127-
// panic-critical begin: we must not panic
128-
// forget moved array elements but not its vec
129-
// old_storage drops empty
153+
// forget the old elements but not the allocation
130154
let mut old_storage = self.into_raw_vec();
131155
old_storage.set_len(0);
132156

157+
// transfer ownership of the elements into the result
133158
result.assume_init()
134-
// panic-critical end
135159
}
136160
}
137161
}
@@ -179,31 +203,46 @@ mod tests {
179203
[75600.94, 17.],
180204
[75601.06, 18.],
181205
];
206+
let answer = array![
207+
[75600.09, 10.],
208+
[75600.21, 11.],
209+
[75600.45, 13.],
210+
[75600.58, 14.],
211+
[75600.82, 16.],
212+
[75600.94, 17.],
213+
[75601.06, 18.],
214+
[75601.33, 12.],
215+
[107998.96, 1.],
216+
[107999.08, 2.],
217+
[107999.20, 3.],
218+
[107999.45, 5.],
219+
[107999.57, 6.],
220+
[107999.81, 8.],
221+
[107999.94, 9.],
222+
[108000.33, 4.],
223+
[108010.69, 7.],
224+
[109000.70, 15.],
225+
];
226+
227+
// f layout copy of a
228+
let mut af = Array::zeros(a.dim().f());
229+
af.assign(&a);
230+
231+
// transposed copy of a
232+
let at = a.t().to_owned();
182233

234+
// c layout permute
183235
let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]);
236+
184237
let b = a.permute_axis(Axis(0), &perm);
185-
assert_eq!(
186-
b,
187-
array![
188-
[75600.09, 10.],
189-
[75600.21, 11.],
190-
[75600.45, 13.],
191-
[75600.58, 14.],
192-
[75600.82, 16.],
193-
[75600.94, 17.],
194-
[75601.06, 18.],
195-
[75601.33, 12.],
196-
[107998.96, 1.],
197-
[107999.08, 2.],
198-
[107999.20, 3.],
199-
[107999.45, 5.],
200-
[107999.57, 6.],
201-
[107999.81, 8.],
202-
[107999.94, 9.],
203-
[108000.33, 4.],
204-
[108010.69, 7.],
205-
[109000.70, 15.],
206-
]
207-
);
238+
assert_eq!(b, answer);
239+
240+
// f layout permute
241+
let bf = af.permute_axis(Axis(0), &perm);
242+
assert_eq!(bf, answer);
243+
244+
// transposed permute
245+
let bt = at.permute_axis(Axis(1), &perm);
246+
assert_eq!(bt, answer.t());
208247
}
209248
}

0 commit comments

Comments
 (0)