Skip to content

Commit

Permalink
FEAT: Add dimension merge function to merge contiguous axes
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Mar 31, 2024
1 parent f13c63e commit d42ee96
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,32 @@ where D: Dimension
}
}

/// Attempt to merge axes if possible, starting from the back
///
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
/// to merge all axes one by one into Axis(3); when/if this fails,
/// it attempts to merge the rest of the axes together into the next
/// axis in line, for example a result could be:
///
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
/// mean axes were merged.
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
n => {
let mut last = n - 1;
for i in (0..last).rev() {
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
last = i;
}
}
}
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
Expand Down Expand Up @@ -820,6 +846,30 @@ where D: Dimension
*strides = new_strides;
}

/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
/// stride
///
/// The axes are sorted according to the .abs() of their stride.
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
debug_assert!(dim.ndim() > 1);
debug_assert_eq!(dim.ndim(), strides.ndim());
// bubble sort axes
let mut changed = true;
while changed {
changed = false;
for i in 0..dim.ndim() - 1 {
// make sure higher stride axes sort before.
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
changed = true;
dim.slice_mut().swap(i, i + 1);
strides.slice_mut().swap(i, i + 1);
}
}
}
}

#[cfg(test)]
mod test
{
Expand All @@ -829,6 +879,7 @@ mod test
can_index_slice_not_custom,
extended_gcd,
max_abs_offset_check_overflow,
merge_axes_from_the_back,
slice_min_max,
slices_intersect,
solve_linear_diophantine_eq,
Expand Down Expand Up @@ -1213,4 +1264,27 @@ mod test
assert_eq!(d, dans);
assert_eq!(s, sans);
}

#[test]
fn test_merge_axes_from_the_back()
{
let dyndim = Dim::<&[usize]>;

let mut d = Dim([3, 4, 5]);
let mut s = Dim([20, 5, 1]);
merge_axes_from_the_back(&mut d, &mut s);
assert_eq!(d, Dim([1, 1, 60]));
assert_eq!(s, Dim([20, 5, 1]));

let mut d = Dim([3, 4, 5, 2]);
let mut s = Dim([80, 20, 2, 1]);
merge_axes_from_the_back(&mut d, &mut s);
assert_eq!(d, Dim([1, 12, 1, 10]));
assert_eq!(s, Dim([80, 20, 2, 1]));
let mut d = d.into_dyn();
let mut s = s.into_dyn();
squeeze(&mut d, &mut s);
assert_eq!(d, dyndim(&[12, 10]));
assert_eq!(s, dyndim(&[20, 1]));
}
}

0 comments on commit d42ee96

Please sign in to comment.