@cookpa I'd like to have a self-contained unit test for these reorientation implementations, something we can run as part of integration tests, etc.
Please see below for an example of what we might do; this provides a test of the py_based implementation.
import ants
import numpy as np
import time
# ==============================================================================
# SECTION 1: THE TWO IMPLEMENTATIONS TO BE COMPARED
# ==============================================================================
def get_physical_space_grid(image: ants.ANTsImage) -> ants.ANTsImage:
"""
Creates a physical space grid image from a reference image.
The output is a multi-component ANTsImage where each voxel's value
is its physical coordinate (e.g., in millimeters). This is achieved by
calculating the physical coordinates for each voxel, creating a separate
scalar image for each coordinate dimension (X, Y, Z), and then merging
them into a single vector image using `ants.merge_channels`.
This function is a from-scratch implementation of the concept of a
physical space grid.
Parameters
----------
image : ants.ANTsImage
A reference ANTsImage that defines the grid's dimensions, spacing,
origin, and direction.
Returns
-------
ants.ANTsImage
A multi-component image of the same size as the input, where each
voxel contains its (x, y, z) physical space coordinate.
"""
dim = image.dimension
shape = image.shape
# 1. Create a grid of voxel indices (i, j, k)
# We use np.meshgrid with 'ij' indexing to match the matrix/array layout.
# This creates a list of arrays, one for each dimension's index.
indices_1d = [np.arange(s) for s in shape]
mesh = np.meshgrid(*indices_1d, indexing='ij')
# Stack the meshgrid arrays to get a single NumPy array of shape
# (nx, ny, nz, 3) where the last dimension holds the (i, j, k) index.
index_grid_np = np.stack(mesh, axis=-1)
# 2. Apply the voxel-to-physical-space transformation formula:
# P = Origin + Direction @ (Spacing * I)
# a. Scale indices by spacing. Broadcasting handles this element-wise.
# Resulting shape is still (nx, ny, nz, 3)
scaled_indices = index_grid_np * image.spacing
# b. Rotate the scaled indices by the direction matrix.
# We use np.einsum for a clean, vectorized matrix multiplication at each voxel.
# 'ij,...j->...i' means: for each voxel vector (...j), multiply it by the
# direction matrix (ij) to get a new rotated vector (...i).
rotated_scaled_indices = np.einsum('ij,...j->...i', image.direction, scaled_indices)
# c. Add the origin. Broadcasting handles this.
physical_grid_np = rotated_scaled_indices + image.origin
# 3. Create a separate ANTsImage for each physical coordinate component (X, Y, Z)
# as required by the implementation constraints.
component_images = []
for d in range(dim):
# Extract the coordinate component (e.g., all X values)
component_array = physical_grid_np[..., d]
# Create a scalar ANTsImage from this component, ensuring it has the
# same header information as the original image.
component_image = ants.from_numpy(
component_array,
origin=image.origin,
spacing=image.spacing,
direction=image.direction
)
component_images.append(component_image)
# 4. Merge the component images into a single multi-component (vector) image.
physical_grid_image = ants.merge_channels(component_images)
return physical_grid_image
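# Quick sanity check (a sketch, not part of the test below): with an identity
# direction matrix, origin (0, 0, 0) and spacing (1.5, 1.2, 1.0), the voxel at
# index (2, 1, 0) should map to the physical point
# origin + direction @ (spacing * index) = (3.0, 1.2, 0.0), e.g.:
#
#   img = ants.make_image((4, 4, 4), spacing=(1.5, 1.2, 1.0))
#   grid = get_physical_space_grid(img)
#   assert np.allclose(grid.numpy()[2, 1, 0], [3.0, 1.2, 0.0])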
def _polar_decomposition_numpy(matrix):
"""
A minimal, numpy-based polar decomposition to extract the rotation matrix Z.
This helper ensures the script is self-contained. X = ZP => Z = U @ Vh
Includes reflection correction to ensure det(Z) = +1.
"""
U, s, Vh = np.linalg.svd(matrix)
Z = U @ Vh
# Correct for reflections to ensure it's a proper rotation matrix
if np.linalg.det(Z) < 0:
Vh[-1, :] *= -1
Z = U @ Vh
return Z
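# Worked example (pure NumPy, nothing ANTs-specific assumed): the polar factor
# of a scaled rotation is the rotation itself, so for any proper rotation R
# (det(R) = +1) and scalar c > 0, _polar_decomposition_numpy(c * R) returns R
# to numerical precision. The determinant check above only kicks in when the
# SVD product U @ Vh comes out as a reflection.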
def deformation_gradient_original_py(warp_image, to_rotation=False, to_inverse_rotation=False):
"""
The ORIGINAL, slow, loop-based pure Python implementation.
This serves as the "gold standard" for our verification.
"""
if not ants.is_image(warp_image):
raise RuntimeError("antsimage is required")
dim = warp_image.dimension
warpnp = warp_image.numpy()
tshp = warp_image.shape
tdir = warp_image.direction
spc = warp_image.spacing
ident = np.eye(dim)
# Calculate gradient
dg_list = []
for k in range(dim):
# The original code had an if/else for dim, but *spc works for both
temp = np.stack(np.gradient(warpnp[..., k], *spc, axis=range(dim)), axis=dim)
dg_list.append(temp)
dg = np.stack(dg_list, axis=dim + 1)
# Loop through every single voxel to apply transforms
it = np.ndindex(tshp)
for i in it:
# Transpose to match ITK Jacobian definition
mat = dg[i].T
        # Rotate each row of the matrix by the direction matrix,
        # i.e. mat_rotated = (tdir @ mat.T).T
mat_rotated = np.zeros_like(mat)
for r in range(dim):
mat_rotated[r, :] = np.dot(tdir, mat[r, :])
# Add identity to form full deformation gradient
mat_final = mat_rotated + ident
dg[i] = mat_final
if to_rotation or to_inverse_rotation:
it_rot = np.ndindex(tshp)
for i in it_rot:
# Perform polar decomposition on each voxel's matrix
rot_matrix = _polar_decomposition_numpy(dg[i])
if to_inverse_rotation:
dg[i] = rot_matrix.T
else:
dg[i] = rot_matrix
newshape = tshp + (dim * dim,)
dg_reshaped = np.reshape(dg, newshape)
return ants.from_numpy(dg_reshaped, origin=warp_image.origin,
spacing=warp_image.spacing, direction=warp_image.direction,
has_components=True)
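# Note on the output layout: the vector image returned above stores dim * dim
# components per voxel in row-major (C) order, so component c corresponds to
# entry (c // dim, c % dim) of that voxel's deformation gradient matrix.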
def deformation_gradient_optimized(warp_image, to_rotation=False, to_inverse_rotation=False):
"""
The NEW, fast, vectorized pure Python/NumPy implementation.
"""
if not ants.is_image(warp_image):
raise RuntimeError("antsimage is required")
dim = warp_image.dimension
tshp = warp_image.shape
tdir = warp_image.direction
spc = warp_image.spacing
warpnp = warp_image.numpy()
gradient_list = [np.gradient(warpnp[..., k], *spc, axis=range(dim)) for k in range(dim)]
# This correctly calculates J.T, where dg[..., i, j] = d(u_j)/d(x_i)
dg = np.stack([np.stack(grad_k, axis=-1) for grad_k in gradient_list], axis=-1)
# *** THE FIX IS HERE ***
# The original loop was equivalent to (tdir @ J.T).T
# Since our `dg` is J.T, we need to compute (tdir @ dg).T
# 1. Compute temp = tdir @ dg
temp = np.einsum('ij,...jk->...ik', tdir, dg)
# 2. Transpose the result
axes = (*range(temp.ndim - 2), temp.ndim - 1, temp.ndim - 2)
dg = np.transpose(temp, axes=axes)
dg += np.eye(dim)
if to_rotation or to_inverse_rotation:
U, s, Vh = np.linalg.svd(dg)
Z = U @ Vh
dets = np.linalg.det(Z)
reflection_mask = dets < 0
Vh[reflection_mask, -1, :] *= -1
Z[reflection_mask] = U[reflection_mask] @ Vh[reflection_mask]
dg = Z
if to_inverse_rotation:
dg = np.transpose(dg, axes=(*range(dg.ndim - 2), dg.ndim - 1, dg.ndim - 2))
new_shape = tshp + (dim * dim,)
dg_reshaped = np.reshape(dg, new_shape)
return ants.from_numpy(dg_reshaped, origin=warp_image.origin,
spacing=warp_image.spacing, direction=warp_image.direction,
has_components=True)
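# Note: np.linalg.svd and np.linalg.det broadcast over all leading (spatial)
# axes, so the rotation extraction above runs as one batched decomposition for
# every voxel at once; this, together with the einsum above, is what replaces
# the per-voxel Python loops of the original implementation.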
# ==============================================================================
# SECTION 2: VERIFICATION FRAMEWORK
# ==============================================================================
def create_test_image_and_warp(shape=(10, 12, 14), spacing=(1.5, 1.2, 1.0)):
"""
Creates a reference image with a non-identity direction matrix
and a corresponding smooth deformation field.
"""
print("--- Step 1: Creating test data ---")
img = ants.make_image(shape, spacing=spacing)
theta = np.deg2rad(30)
cos_t, sin_t = np.cos(theta), np.sin(theta)
rotation_matrix = np.array([
[cos_t, -sin_t, 0.1],
[sin_t, cos_t, 0.2],
[-0.1, -0.2, 0.97] # Non-trivial rotation
])
rotation_matrix, _ = np.linalg.qr(rotation_matrix) # Ensure it's perfectly orthogonal
img.set_direction(rotation_matrix)
print("Set reference image with direction matrix:\n", np.round(img.direction, 3))
# *** THIS IS THE CORRECTED LINE ***
grid = get_physical_space_grid(img)
    print(grid)
    gridL = ants.split_channels(grid)
warp_field_np = np.zeros((*shape, 3))
warp_field_np[..., 0] = np.sin(gridL[0].numpy() / 5) * 2.0
warp_field_np[..., 1] = np.cos(gridL[1].numpy() / 6) * 1.5
warp_field_np[..., 2] = np.sin(gridL[2].numpy() / 4) * 3.0
warp_img = ants.from_numpy(
warp_field_np, origin=img.origin, spacing=img.spacing,
direction=img.direction, has_components=True
)
print("Successfully created a simulated warp field.")
return warp_img
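# Design note: the warp components are smooth functions of the physical (not
# index) coordinates, so together with the rotated direction matrix they should
# exercise the direction-handling code paths that differ between the two
# implementations.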
def run_verification_test():
"""Main function to run the verification test."""
warp_image = create_test_image_and_warp()
tolerance = 1e-6
test_configs = [
{'name': 'Standard Deformation Gradient', 'kwargs': {'to_rotation': False, 'to_inverse_rotation': False}},
{'name': 'Rotation Matrix', 'kwargs': {'to_rotation': True, 'to_inverse_rotation': False}},
{'name': 'Inverse Rotation Matrix', 'kwargs': {'to_rotation': False, 'to_inverse_rotation': True}}
]
for config in test_configs:
print(f"\n--- Test Case: {config['name']} ---")
# Run original, slow implementation
print("Running original (slow) implementation...")
start_time_orig = time.time()
result_orig = deformation_gradient_original_py(warp_image, **config['kwargs'])
time_orig = time.time() - start_time_orig
print(f"Original finished in {time_orig:.4f} seconds.")
# Run new, optimized implementation
print("Running optimized implementation...")
start_time_opt = time.time()
result_opt = deformation_gradient_optimized(warp_image, **config['kwargs'])
time_opt = time.time() - start_time_opt
print(f"Optimized finished in {time_opt:.4f} seconds.")
# Compare results and assert correctness
are_close = np.allclose(result_orig.numpy(), result_opt.numpy(), atol=tolerance)
speedup = time_orig / time_opt if time_opt > 0 else float('inf')
print(f"\nVerification: {'PASSED' if are_close else 'FAILED'}")
print(f"Performance Gain: {speedup:.2f}x")
assert are_close, f"Mismatch found in test case: {config['name']}"
print("\n\nAll tests passed successfully! The optimized implementation is correct and significantly faster.")
if __name__ == "__main__":
    run_verification_test()
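To run this as part of CI, a minimal pytest wrapper could look like the sketch below; it assumes the functions above have been moved into an importable module (the module name reorientation_utils here is just a placeholder).

import numpy as np
import pytest

# Placeholder module name; point this at wherever the two implementations live.
from reorientation_utils import (
    create_test_image_and_warp,
    deformation_gradient_original_py,
    deformation_gradient_optimized,
)

@pytest.mark.parametrize("kwargs", [
    {"to_rotation": False, "to_inverse_rotation": False},
    {"to_rotation": True, "to_inverse_rotation": False},
    {"to_rotation": False, "to_inverse_rotation": True},
])
def test_deformation_gradient_matches_reference(kwargs):
    warp_image = create_test_image_and_warp()
    reference = deformation_gradient_original_py(warp_image, **kwargs)
    optimized = deformation_gradient_optimized(warp_image, **kwargs)
    np.testing.assert_allclose(optimized.numpy(), reference.numpy(), atol=1e-6)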