Skip to content

Axis permutation to correctly handle cycles using bitmask approach #1505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

NewBornRustacean
Copy link
Contributor

Overview

What's up! 😃 Currently, there're two functions to deal with permuting and reversing arrays(permuted_axes and reversed_axes).
As mentioned from related issue, they're not in-place, so I added permute_axes and reverse_axes. The new implementation for permute_axes uses a bitmask-based approach to efficiently track and process cycles in axis permutations!

Related issue

Changes

  1. permute_axes (in-place):

    • A new function with a cycle detection algorithm based on bitmasking.
    • try to maintain the zero-allocation, in-place modification approach
  2. reverse_axes (in-place):

    • simply use &mut self

Details on permute_axes

How it works

The new implementation uses a bitmask to track visited axes during permutation cycles:

  1. Cycle Detection: When permuting axes, we identify cycles of positions that need to be updated.
  2. Bitmask Tracking: A single usize is used as a bitmask to efficiently track which axes have been processed.
  3. In-place Updates: Each cycle is processed in-place by following the permutation pattern.

Example

Consider permuting a 2×2 array from [[1, 2], [3, 4]] with axes permutation [1, 0] (transpose):

// Original array
let mut arr = array![[1, 2], [3, 4]];
// Original: dims = [2, 2], strides = [2, 1]

// Permute axes [1, 0]
arr.permute_axes([1, 0]);
// Result: dims = [2, 2], strides = [1, 2]
// Array is now [[1, 3], [2, 4]]

Here's how the algorithm processes this:

  1. Start with axis 1 (new_axis=0):

    • Store initial values: dim[1]=2, stride[1]=1
    • Follow cycle: 1 → 0 → 1
    • Update dim[1] = dim[0] = 2
    • Update stride[1] = stride[0] = 2
    • Mark axes 0 and 1 as visited
  2. Process axis 0 (new_axis=1):

    • Already visited in the previous cycle, skip

The bitmask approach ensures each axis is processed exactly once, even in complex permutations with multiple cycles. The key operations are:

  • visited |= 1 << axis: Mark axis as visited
  • (visited & (1 << axis)) != 0: Check if axis is visited

@akern40 @nilgoyette added test cases similar to the original functions(premuted_axes and reversed_axes), but there might be missing cases. please take a look when you guys have some time! 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants