Skip to content

tensor reorientation and deformation gradient tests #27

@stnava

Description

@stnava

@cookpa I'd like to have a self-contained unit test for these reorientation implementations — something we can run as part of the integration tests, etc.

#26

Please see below for an example of what we might do. This provides a test of the pure-Python implementation.

import ants
import numpy as np
import time

# ==============================================================================
# SECTION 1: THE TWO IMPLEMENTATIONS TO BE COMPARED
# ==============================================================================

def get_physical_space_grid(image: ants.ANTsImage) -> ants.ANTsImage:
    """
    Build a vector image whose voxels hold their own physical-space coordinates.

    Every voxel index ``I`` of ``image`` is mapped to the physical point
    ``P = origin + direction @ (spacing * I)``. One scalar image is created
    per coordinate axis, each carrying ``image``'s header (origin, spacing,
    direction), and the scalar images are combined with
    ``ants.merge_channels``.

    Parameters
    ----------
    image : ants.ANTsImage
        Reference image supplying the grid size, spacing, origin and
        direction.

    Returns
    -------
    ants.ANTsImage
        Multi-component image on the same grid as ``image``. Component ``d``
        of each voxel is that voxel's physical coordinate along axis ``d``,
        in the image's physical units.
    """
    # Integer voxel indices with the index vector on the trailing axis:
    # shape (*image.shape, len(image.shape)).
    voxel_index = np.moveaxis(np.indices(image.shape), 0, -1)

    # Scale by spacing, rotate by the direction matrix (a matrix-vector
    # product at every voxel), then translate by the origin.
    scaled = voxel_index * image.spacing
    physical = np.einsum('ij,...j->...i', image.direction, scaled) + image.origin

    # One scalar image per physical axis, all sharing the reference header.
    channels = [
        ants.from_numpy(
            physical[..., axis],
            origin=image.origin,
            spacing=image.spacing,
            direction=image.direction,
        )
        for axis in range(image.dimension)
    ]
    return ants.merge_channels(channels)

def _polar_decomposition_numpy(matrix):
    """
    A minimal, numpy-based polar decomposition to extract the rotation matrix Z.
    This helper ensures the script is self-contained. X = ZP => Z = U @ Vh
    Includes reflection correction to ensure det(Z) = +1.
    """
    U, s, Vh = np.linalg.svd(matrix)
    Z = U @ Vh
    # Correct for reflections to ensure it's a proper rotation matrix
    if np.linalg.det(Z) < 0:
        Vh[-1, :] *= -1
        Z = U @ Vh
    return Z

def deformation_gradient_original_py(warp_image: ants.ANTsImage, to_rotation: bool = False, to_inverse_rotation: bool = False) -> ants.ANTsImage:
    """
    Reference (slow, loop-based, pure Python) deformation-gradient computation.

    This serves as the "gold standard" that ``deformation_gradient_optimized``
    is verified against. Its statements are intentionally left as-is,
    because the vectorized version must reproduce exactly this arithmetic.

    For every voxel it computes ``F = J @ D.T + I``, where:

    - ``J[k, i]`` is the finite-difference derivative (``np.gradient``, with
      the image spacing as sample distances) of warp component ``k`` along
      array axis ``i``;
    - ``D`` is the image direction matrix;
    - ``I`` is the identity.

    Optionally ``F`` is then replaced by the rotation factor of its polar
    decomposition, or by that rotation's transpose.

    Parameters
    ----------
    warp_image : ants.ANTsImage
        Displacement image whose ``numpy()`` array has components on the last
        axis. Presumably it has ``dimension`` components (one per spatial
        axis); TODO confirm with callers.
    to_rotation : bool
        If True, return the proper rotation ``Z`` from the polar
        decomposition of ``F`` instead of ``F``.
    to_inverse_rotation : bool
        If True, return ``Z.T``. Takes precedence over ``to_rotation``.

    Returns
    -------
    ants.ANTsImage
        Image with the warp's header and ``dim * dim`` components per voxel:
        the row-major flattening of each voxel's ``dim x dim`` matrix.

    Raises
    ------
    RuntimeError
        If ``warp_image`` is not an ANTs image.
    """
    if not ants.is_image(warp_image):
        raise RuntimeError("antsimage is required")

    dim = warp_image.dimension
    warpnp = warp_image.numpy()
    tshp = warp_image.shape
    tdir = warp_image.direction
    spc = warp_image.spacing
    ident = np.eye(dim)

    # Calculate gradient: dg[..., i, k] = d(u_k)/d(x_i) along array axes,
    # i.e. each voxel holds J transposed.
    # NOTE: for a floating-point warp, np.gradient keeps warpnp's dtype, so dg
    # (and every later assignment into it) uses that dtype.
    dg_list = []
    for k in range(dim):
        # The original code had an if/else for dim, but *spc works for both
        temp = np.stack(np.gradient(warpnp[..., k], *spc, axis=range(dim)), axis=dim)
        dg_list.append(temp)
    dg = np.stack(dg_list, axis=dim + 1)

    # Loop through every single voxel to apply transforms
    it = np.ndindex(tshp)
    for i in it:
        # Transpose to match ITK Jacobian definition: mat[k, i] = d(u_k)/d(x_i)
        mat = dg[i].T
        
        # Transform rows by the direction matrix: row r becomes tdir @ mat[r, :],
        # so mat_rotated = mat @ tdir.T (i.e. J @ D^T).
        # zeros_like keeps mat's dtype, so each np.dot result is cast to it.
        mat_rotated = np.zeros_like(mat)
        for r in range(dim):
            mat_rotated[r, :] = np.dot(tdir, mat[r, :])
        
        # Add identity to form full deformation gradient
        mat_final = mat_rotated + ident
        dg[i] = mat_final

    if to_rotation or to_inverse_rotation:
        it_rot = np.ndindex(tshp)
        for i in it_rot:
            # Perform polar decomposition on each voxel's matrix (proper rotation, det +1)
            rot_matrix = _polar_decomposition_numpy(dg[i])
            if to_inverse_rotation:
                # A rotation's inverse is its transpose
                dg[i] = rot_matrix.T
            else:
                dg[i] = rot_matrix

    # Flatten each voxel's dim x dim matrix row-major into dim*dim components
    newshape = tshp + (dim * dim,)
    dg_reshaped = np.reshape(dg, newshape)
    return ants.from_numpy(dg_reshaped, origin=warp_image.origin,
                           spacing=warp_image.spacing, direction=warp_image.direction,
                           has_components=True)


def deformation_gradient_optimized(warp_image, to_rotation=False, to_inverse_rotation=False):
    """
    Vectorized NumPy deformation-gradient computation.

    Produces the same result as ``deformation_gradient_original_py``: per voxel
    ``F = J @ D.T + I``, where ``J[k, i]`` is the spacing-scaled derivative of
    warp component ``k`` along array axis ``i`` and ``D`` is the direction
    matrix. All voxels are processed at once with broadcasting instead of
    Python loops.

    Parameters
    ----------
    warp_image : ants.ANTsImage
        Displacement image whose ``numpy()`` array has components on the last
        axis.
    to_rotation : bool
        If True, return the proper rotation ``Z`` (det +1) from the polar
        decomposition of ``F``.
    to_inverse_rotation : bool
        If True, return ``Z.T``. Takes precedence over ``to_rotation``.

    Returns
    -------
    ants.ANTsImage
        Image with the warp's header and ``dim * dim`` components per voxel
        (row-major flattening of each voxel's matrix).

    Raises
    ------
    RuntimeError
        If ``warp_image`` is not an ANTs image.
    """
    if not ants.is_image(warp_image):
        raise RuntimeError("antsimage is required")

    ndim = warp_image.dimension
    grid_shape = warp_image.shape
    direction = warp_image.direction
    voxel_size = warp_image.spacing
    field = warp_image.numpy()

    # jac_t[..., i, k] = d(u_k)/d(x_i), i.e. J transposed at every voxel.
    per_component = []
    for comp in range(ndim):
        partials = np.gradient(field[..., comp], *voxel_size, axis=range(ndim))
        per_component.append(np.stack(partials, axis=-1))
    jac_t = np.stack(per_component, axis=-1)

    # The reference loop computes J @ D^T; obtain it as (D @ J^T)^T.
    jac = np.swapaxes(np.einsum('ij,...jk->...ik', direction, jac_t), -1, -2)
    jac += np.eye(ndim)

    if to_rotation or to_inverse_rotation:
        # Batched polar decomposition: Z = U @ Vh, with the last row of Vh
        # negated wherever Z would otherwise be a reflection.
        u, _, vh = np.linalg.svd(jac)
        rot = u @ vh
        is_reflection = np.linalg.det(rot) < 0
        vh[is_reflection, -1, :] *= -1
        rot[is_reflection] = u[is_reflection] @ vh[is_reflection]
        jac = np.swapaxes(rot, -1, -2) if to_inverse_rotation else rot

    flattened = np.reshape(jac, grid_shape + (ndim * ndim,))
    return ants.from_numpy(flattened, origin=warp_image.origin,
                           spacing=warp_image.spacing, direction=warp_image.direction,
                           has_components=True)

# ==============================================================================
# SECTION 2: VERIFICATION FRAMEWORK
# ==============================================================================

def create_test_image_and_warp(shape=(10, 12, 14), spacing=(1.5, 1.2, 1.0)):
    """
    Create a smooth synthetic 3-D displacement field on an obliquely oriented grid.

    A reference image of size ``shape`` and voxel size ``spacing`` is given a
    non-identity (rotated) direction matrix. A displacement field is then
    defined analytically from each voxel's physical coordinates. Only the
    displacement image is returned; it carries the reference image's header
    (origin, spacing, direction).

    Parameters
    ----------
    shape : tuple of int
        Grid size; must have exactly three entries.
    spacing : tuple of float
        Voxel spacing; must have exactly three entries.

    Returns
    -------
    ants.ANTsImage
        Three-component displacement image on the reference grid.

    Raises
    ------
    ValueError
        If ``shape`` or ``spacing`` is not three-dimensional. The direction
        matrix and displacement formulas below are 3-D only.
    """
    if len(shape) != 3 or len(spacing) != 3:
        raise ValueError(
            f"shape and spacing must be 3-D, got shape={shape!r}, spacing={spacing!r}"
        )

    print("--- Step 1: Creating test data ---")
    img = ants.make_image(shape, spacing=spacing)

    theta = np.deg2rad(30)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    # A 30-degree in-plane rotation with small out-of-plane terms, so the
    # direction matrix is non-trivial along all three axes.
    approx_rotation = np.array([
        [cos_t, -sin_t, 0.1],
        [sin_t, cos_t, 0.2],
        [-0.1, -0.2, 0.97],
    ])
    # QR determines Q only up to the sign of each column. Flip columns so R
    # has a positive diagonal. Q is then the Gram-Schmidt orthonormalization
    # of the columns above: it stays close to the intended rotation and has
    # the same determinant sign as the input (positive, i.e. a proper
    # rotation here).
    q, r = np.linalg.qr(approx_rotation)
    rotation_matrix = q * np.sign(np.diag(r))
    img.set_direction(rotation_matrix)
    print("Set reference image with direction matrix:\n", np.round(img.direction, 3))

    # Physical (x, y, z) coordinate of every voxel, one array per axis.
    coords = [channel.numpy() for channel in ants.split_channels(get_physical_space_grid(img))]

    # Smooth analytic displacement, one component per physical axis.
    warp_field_np = np.stack(
        [
            np.sin(coords[0] / 5) * 2.0,
            np.cos(coords[1] / 6) * 1.5,
            np.sin(coords[2] / 4) * 3.0,
        ],
        axis=-1,
    )
    warp_img = ants.from_numpy(
        warp_field_np, origin=img.origin, spacing=img.spacing,
        direction=img.direction, has_components=True
    )
    print("Successfully created a simulated warp field.")
    return warp_img

def run_verification_test():
    """
    Compare the reference and optimized deformation-gradient implementations.

    Builds a synthetic warp with ``create_test_image_and_warp``. For each
    output mode (full deformation gradient, rotation, inverse rotation) it:

    - runs both implementations;
    - prints their run times and the speedup;
    - asserts that the outputs agree via ``np.allclose(..., atol=1e-6)``.

    Raises
    ------
    AssertionError
        If the two implementations disagree for any test case.
    """
    warp_image = create_test_image_and_warp()
    tolerance = 1e-6

    test_configs = [
        {'name': 'Standard Deformation Gradient', 'kwargs': {'to_rotation': False, 'to_inverse_rotation': False}},
        {'name': 'Rotation Matrix', 'kwargs': {'to_rotation': True, 'to_inverse_rotation': False}},
        {'name': 'Inverse Rotation Matrix', 'kwargs': {'to_rotation': False, 'to_inverse_rotation': True}}
    ]

    for config in test_configs:
        print(f"\n--- Test Case: {config['name']} ---")

        # perf_counter is monotonic and high-resolution; time.time() is wall
        # clock time and can jump or be too coarse for short runs.
        print("Running original (slow) implementation...")
        start = time.perf_counter()
        result_orig = deformation_gradient_original_py(warp_image, **config['kwargs'])
        time_orig = time.perf_counter() - start
        print(f"Original finished in {time_orig:.4f} seconds.")

        print("Running optimized implementation...")
        start = time.perf_counter()
        result_opt = deformation_gradient_optimized(warp_image, **config['kwargs'])
        time_opt = time.perf_counter() - start
        print(f"Optimized finished in {time_opt:.4f} seconds.")

        # Compare results and assert correctness
        are_close = np.allclose(result_orig.numpy(), result_opt.numpy(), atol=tolerance)
        speedup = time_orig / time_opt if time_opt > 0 else float('inf')

        print(f"\nVerification: {'PASSED' if are_close else 'FAILED'}")
        print(f"Performance Gain: {speedup:.2f}x")
        assert are_close, f"Mismatch found in test case: {config['name']}"

    print("\n\nAll tests passed successfully! The optimized implementation is correct and significantly faster.")

if __name__ == "__main__":
    # Running this file directly executes the full comparison suite.
    run_verification_test()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions