Skip to content

Commit cef3901

Browse files
authored
Merge pull request #930 from rust-ndarray/faster-sort
Speed up applying the permutation in sort-axis example
2 parents f7a1277 + a3bceb8 commit cef3901

File tree

1 file changed

+70
-31
lines changed

1 file changed

+70
-31
lines changed

examples/sort-axis.rs

+70-31
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use ndarray::prelude::*;
22
use ndarray::{Data, RemoveAxis, Zip};
33

4+
use rawpointer::PointerExt;
5+
46
use std::cmp::Ordering;
57
use std::ptr::copy_nonoverlapping;
68

@@ -97,8 +99,8 @@ where
9799
where
98100
D: RemoveAxis,
99101
{
100-
let axis = axis;
101102
let axis_len = self.len_of(axis);
103+
let axis_stride = self.stride_of(axis);
102104
assert_eq!(axis_len, perm.indices.len());
103105
debug_assert!(perm.correct());
104106

@@ -112,26 +114,48 @@ where
112114
// logically move ownership of all elements from self into result
113115
// the result realizes this ownership at .assume_init() further down
114116
let mut moved_elements = 0;
117+
118+
// the permutation vector is used like this:
119+
//
120+
// index: 0 1 2 3 (index in result)
121+
// permut: 2 3 0 1 (index in the source)
122+
//
123+
// move source 2 -> result 0,
124+
// move source 3 -> result 1,
125+
// move source 0 -> result 2,
126+
// move source 1 -> result 3,
127+
// etc.
128+
129+
let source_0 = self.raw_view().index_axis_move(axis, 0);
130+
115131
Zip::from(&perm.indices)
116132
.and(result.axis_iter_mut(axis))
117133
.for_each(|&perm_i, result_pane| {
118-
// possible improvement: use unchecked indexing for `index_axis`
134+
// Use a shortcut to avoid bounds checking in `index_axis` for the source.
135+
//
136+
// It works because for any given element pointer in the array we have the
137+
// relationship:
138+
//
139+
// .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j)
140+
//
141+
// where + is pointer arithmetic on the element pointers.
142+
//
143+
// Here source_0 plus the offset is equivalent to self.index_axis(axis, perm_i)
119144
Zip::from(result_pane)
120-
.and(self.index_axis(axis, perm_i))
121-
.for_each(|to, from| {
145+
.and(source_0.clone())
146+
.for_each(|to, from_0| {
147+
let from = from_0.stride_offset(axis_stride, perm_i);
122148
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
123149
moved_elements += 1;
124150
});
125151
});
126152
debug_assert_eq!(result.len(), moved_elements);
127-
// panic-critical begin: we must not panic
128-
// forget moved array elements but not its vec
129-
// old_storage drops empty
153+
// forget the old elements but not the allocation
130154
let mut old_storage = self.into_raw_vec();
131155
old_storage.set_len(0);
132156

157+
// transfer ownership of the elements into the result
133158
result.assume_init()
134-
// panic-critical end
135159
}
136160
}
137161
}
@@ -179,31 +203,46 @@ mod tests {
179203
[75600.94, 17.],
180204
[75601.06, 18.],
181205
];
206+
let answer = array![
207+
[75600.09, 10.],
208+
[75600.21, 11.],
209+
[75600.45, 13.],
210+
[75600.58, 14.],
211+
[75600.82, 16.],
212+
[75600.94, 17.],
213+
[75601.06, 18.],
214+
[75601.33, 12.],
215+
[107998.96, 1.],
216+
[107999.08, 2.],
217+
[107999.20, 3.],
218+
[107999.45, 5.],
219+
[107999.57, 6.],
220+
[107999.81, 8.],
221+
[107999.94, 9.],
222+
[108000.33, 4.],
223+
[108010.69, 7.],
224+
[109000.70, 15.],
225+
];
226+
227+
// f layout copy of a
228+
let mut af = Array::zeros(a.dim().f());
229+
af.assign(&a);
230+
231+
// transposed copy of a
232+
let at = a.t().to_owned();
182233

234+
// c layout permute
183235
let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]);
236+
184237
let b = a.permute_axis(Axis(0), &perm);
185-
assert_eq!(
186-
b,
187-
array![
188-
[75600.09, 10.],
189-
[75600.21, 11.],
190-
[75600.45, 13.],
191-
[75600.58, 14.],
192-
[75600.82, 16.],
193-
[75600.94, 17.],
194-
[75601.06, 18.],
195-
[75601.33, 12.],
196-
[107998.96, 1.],
197-
[107999.08, 2.],
198-
[107999.20, 3.],
199-
[107999.45, 5.],
200-
[107999.57, 6.],
201-
[107999.81, 8.],
202-
[107999.94, 9.],
203-
[108000.33, 4.],
204-
[108010.69, 7.],
205-
[109000.70, 15.],
206-
]
207-
);
238+
assert_eq!(b, answer);
239+
240+
// f layout permute
241+
let bf = af.permute_axis(Axis(0), &perm);
242+
assert_eq!(bf, answer);
243+
244+
// transposed permute
245+
let bt = at.permute_axis(Axis(1), &perm);
246+
assert_eq!(bt, answer.t());
208247
}
209248
}

0 commit comments

Comments
 (0)