Commit e62ce25: orthogonalize AD/PD src/sinks

Jordan DeKraker committed Feb 18, 2025
1 parent 32b66fe commit e62ce25
Showing 4 changed files with 152 additions and 164 deletions.
92 changes: 78 additions & 14 deletions hippunfold/workflow/rules/native_surf.smk
@@ -250,40 +250,54 @@ rule map_src_sink_sdt_to_surf:
         "wb_command -volume-to-surface-mapping {input.sdt} {input.surf_gii} {output.sdt} -trilinear"


-rule laplace_beltrami:
+rule joint_smooth_ap_pd_edges:
+    """ ensures non-overlapping and full labelling of AP/PD edges """
     input:
-        surf_gii=bids(
-            root=root,
+        ap_src=bids(
+            root=work,
             datatype="surf",
-            suffix="midthickness.surf.gii",
+            suffix="sdt.shape.gii",
             space="corobl",
             hemi="{hemi}",
+            dir="AP",
+            desc="src",
             label="{label}",
             **inputs.subj_wildcards,
         ),
-        src_sdt=bids(
+        ap_sink=bids(
             root=work,
             datatype="surf",
             suffix="sdt.shape.gii",
             space="corobl",
             hemi="{hemi}",
-            dir="{dir}",
+            dir="AP",
+            desc="sink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+        pd_src=bids(
+            root=work,
+            datatype="surf",
+            suffix="sdt.shape.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="PD",
             desc="src",
             label="{label}",
             **inputs.subj_wildcards,
         ),
-        sink_sdt=bids(
+        pd_sink=bids(
             root=work,
             datatype="surf",
             suffix="sdt.shape.gii",
             space="corobl",
             hemi="{hemi}",
-            dir="{dir}",
+            dir="PD",
             desc="sink",
             label="{label}",
             **inputs.subj_wildcards,
         ),
-        boundary=bids(
+        edges=bids(
             root=work,
             datatype="surf",
             suffix="boundary.label.gii",
@@ -293,11 +307,61 @@ rule laplace_beltrami:
             **inputs.subj_wildcards,
         ),
     params:
-        min_dist_percentile=1,
-        max_dist_percentile=10,
-        min_terminal_vertices=lambda wildcards: 5 if wildcards.dir == "AP" else 100,  # TODO: instead of # of vertices, we should compute the total length of the segment
-        threshold_method=lambda wildcards: (
-            "percentile" if wildcards.dir == "AP" else "firstminima"
-        ),
+        min_terminal_vertices=5,  # min number of vertices per src/sink
     output:
+        ap=bids(
+            root=work,
+            datatype="surf",
+            suffix="mask.label.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="AP",
+            desc="srcsink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+        pd=bids(
+            root=work,
+            datatype="surf",
+            suffix="mask.label.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="PD",
+            desc="srcsink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+    container:
+        config["singularity"]["autotop"]
+    conda:
+        "../envs/pyvista.yaml"
+    group:
+        "subj"
+    script:
+        "../scripts/postproc_boundary_vertices.py"
+
+
+rule laplace_beltrami:
+    input:
+        surf_gii=bids(
+            root=root,
+            datatype="surf",
+            suffix="midthickness.surf.gii",
+            space="corobl",
+            hemi="{hemi}",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+        src_sink_mask=bids(
+            root=work,
+            datatype="surf",
+            suffix="mask.label.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="{dir}",
+            desc="srcsink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+    output:
         coords=bids(
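The heart of this change is that every boundary vertex now receives exactly one of the four terminal labels, so the AP and PD src/sink sets can no longer overlap (the per-direction rules previously thresholded each SDT independently, which could in principle claim the same vertex twice). A minimal numpy sketch of that idea, with all values hypothetical:

import numpy as np

# Hypothetical signed distances from 4 boundary vertices to each terminal;
# rows are vertices, columns are (apsrc, apsink, pdsrc, pdsink).
sdt = np.array(
    [
        [0.2, 0.9, 0.8, 0.7],
        [0.9, 0.1, 0.8, 0.9],
        [0.8, 0.8, 0.3, 0.9],
        [0.7, 0.9, 0.9, 0.2],
    ]
)
winner = np.argmin(sdt, axis=1)  # one terminal per vertex, so no overlaps
print(winner)  # -> [0 1 2 3]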
10 changes: 10 additions & 0 deletions hippunfold/workflow/scripts/get_boundary_vertices.py
@@ -67,6 +67,16 @@ def read_surface_from_gifti(surf_gii):
    f"Boundary scalar array created. {np.sum(boundary_scalars)} boundary vertices marked."
 )

+# Find the largest connected component within this sub-mesh
+sub_mesh = pv.PolyData(surface.points, surface.faces).extract_points(
+    boundary_scalars.astype(bool), adjacent_cells=True
+)
+largest_component = sub_mesh.connectivity(largest=True)
+boundary_scalars = np.zeros(surface.n_points, dtype=int)
+boundary_scalars[largest_component.point_data["vtkOriginalPointIds"]] = 1
+logger.info("Applying largest connected component")
+
+
 logger.info("Saving GIFTI label file...")

 # Create a GIFTI label data array
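To see why this block can map the largest component back onto the full surface, here is a self-contained pyvista sketch (toy mesh, hypothetical values) mirroring the extract_points / connectivity(largest=True) calls above:

import numpy as np
import pyvista as pv

# Two disjoint patches: A (2 triangles, 4 points) and B (1 triangle, 3 points)
points = np.array(
    [
        [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0],
        [5.0, 0.0, 0.0], [6.0, 0.0, 0.0], [5.0, 1.0, 0.0],
    ]
)
faces = np.hstack([[3, 0, 1, 2], [3, 1, 3, 2], [3, 4, 5, 6]])
mesh = pv.PolyData(points, faces)

marked = np.ones(mesh.n_points, dtype=bool)  # pretend all vertices are boundary
sub_mesh = mesh.extract_points(marked, adjacent_cells=True)

# extract_points records vtkOriginalPointIds, and connectivity() carries point
# data through, so the largest component indexes back into the full mesh.
largest = sub_mesh.connectivity(largest=True)
keep = np.zeros(mesh.n_points, dtype=int)
keep[largest.point_data["vtkOriginalPointIds"]] = 1
print(keep)  # -> [1 1 1 1 0 0 0]: only the larger patch survives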
153 changes: 3 additions & 150 deletions hippunfold/workflow/scripts/laplace_beltrami.py
@@ -3,124 +3,6 @@
 from scipy.sparse import diags, linalg, lil_matrix
 from lib.utils import setup_logger

 log_file = snakemake.log[0] if snakemake.log else None
 logger = setup_logger(log_file)

-import numpy as np
-
-import numpy as np
-from scipy.signal import argrelextrema
-
-
-def get_terminal_indices_firstminima(
-    sdt, min_vertices, boundary_mask, bins=100, smoothing_window=5
-):
-    """
-    Gets the terminal (src/sink) vertex indices by determining an adaptive threshold
-    using the first local minimum of the histogram of `sdt` values.
-    Parameters:
-    - sdt: Signed distance transform array.
-    - min_vertices: The minimum number of vertices required.
-    - boundary_mask: Boolean or binary mask indicating boundary regions.
-    - bins: Number of bins to use in the histogram (default: 100).
-    - smoothing_window: Window size for moving average smoothing (default: 5).
-    Returns:
-    - indices: List of terminal vertex indices.
-    Raises:
-    - ValueError: If the minimum number of vertices cannot be found.
-    """
-
-    # Extract SDT values within the boundary mask
-    sdt_values = sdt[boundary_mask == 1]
-
-    # Compute histogram
-    hist, bin_edges = np.histogram(sdt_values, bins=bins, density=True)
-    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
-
-    # Smooth the histogram using a simple moving average
-    smoothed_hist = np.convolve(
-        hist, np.ones(smoothing_window) / smoothing_window, mode="same"
-    )
-
-    # Find local minima
-    minima_indices = argrelextrema(smoothed_hist, np.less)[0]
-
-    if len(minima_indices) == 0:
-        raise ValueError("No local minima found in the histogram.")
-
-    # Select the first local minimum after the first peak
-    first_minimum_bin = bin_centers[minima_indices[0]]
-
-    # Select indices where SDT is below this threshold
-    indices = np.where((sdt < first_minimum_bin) & (boundary_mask == 1))[0].tolist()
-
-    if len(indices) >= min_vertices:
-        return indices
-
-    raise ValueError(
-        f"Unable to find minimum of {min_vertices} vertices on boundary within the first local minimum threshold."
-    )
-
-
-def get_terminal_indices_percentile(
-    sdt, min_percentile, max_percentile, min_vertices, boundary_mask
-):
-    """
-    Gets the terminal (src/sink) vertex indices by sweeping a percentile-based threshold
-    of the signed distance transform (sdt), ensuring at least `min_vertices` are selected.
-    Instead of a fixed distance range, this function dynamically determines the threshold
-    by scanning from `min_percentile` to `max_percentile`.
-    Parameters:
-    - sdt: Signed distance transform array.
-    - min_percentile: Starting percentile for thresholding (0-100).
-    - max_percentile: Maximum percentile for thresholding (0-100).
-    - min_vertices: The minimum number of vertices required.
-    - boundary_mask: Boolean or binary mask indicating boundary regions.
-    Returns:
-    - indices: List of terminal vertex indices.
-    Raises:
-    - ValueError: If the minimum number of vertices cannot be found.
-    """
-
-    for percentile in np.arange(min_percentile, max_percentile, 0.5):
-        dist_threshold = np.percentile(sdt[boundary_mask == 1], percentile)
-        indices = np.where((sdt < dist_threshold) & (boundary_mask == 1))[0].tolist()
-
-        if len(indices) >= min_vertices:
-            logger.info(
-                f"Using {percentile}-th percentile to obtain sdt threshold of {dist_threshold}, with {len(indices)} vertices"
-            )
-            return indices
-
-    raise ValueError(
-        f"Unable to find minimum of {min_vertices} vertices on boundary within the {max_percentile}th percentile of distances"
-    )
-
-
-def get_terminal_indices_threshold(
-    sdt, min_dist, max_dist, min_vertices, boundary_mask
-):
-    """
-    Gets the terminal (src/sink) vertex indices based on distance to the src/sink mask,
-    a boundary mask, and a minimum number of vertices. The distance from the mask is
-    swept from min_dist to max_dist until min_vertices is achieved, else an
-    exception is thrown."""
-
-    for dist in np.linspace(min_dist, max_dist, 20):
-        indices = np.where((sdt < dist) & (boundary_mask == 1))[0].tolist()
-        if len(indices) >= min_vertices:
-            return indices
-    raise ValueError(
-        f"Unable to find minimum of {min_vertices} vertices on boundary, within {max_dist}mm of the terminal mask"
-    )
-
-
 def solve_laplace_beltrami_open_mesh(vertices, faces, boundary_conditions=None):
     """
@@ -211,38 +93,9 @@ def solve_laplace_beltrami_open_mesh(vertices, faces, boundary_conditions=None):
 vertices = surf.agg_data("NIFTI_INTENT_POINTSET")
 faces = surf.agg_data("NIFTI_INTENT_TRIANGLE")

-boundary_mask = nib.load(snakemake.input.boundary).agg_data()
-src_sdt = nib.load(snakemake.input.src_sdt).agg_data()
-sink_sdt = nib.load(snakemake.input.sink_sdt).agg_data()
-
-if snakemake.params.threshold_method == "percentile":
-    src_indices = get_terminal_indices_percentile(
-        src_sdt,
-        snakemake.params.min_dist_percentile,
-        snakemake.params.max_dist_percentile,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
-    sink_indices = get_terminal_indices_percentile(
-        sink_sdt,
-        snakemake.params.min_dist_percentile,
-        snakemake.params.max_dist_percentile,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
-
-elif snakemake.params.threshold_method == "firstminima":
-    src_indices = get_terminal_indices_firstminima(
-        src_sdt,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
-    sink_indices = get_terminal_indices_firstminima(
-        sink_sdt,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
+src_sink_mask = nib.load(snakemake.input.src_sink_mask).agg_data()
+src_indices = np.where(src_sink_mask == 1)[0]
+sink_indices = np.where(src_sink_mask == 2)[0]


 logger.info(f"# of src boundary vertices: {len(src_indices)}")
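For reference, a minimal usage sketch of solve_laplace_beltrami_open_mesh on a toy two-triangle strip. The Dirichlet convention here (src pinned to 0, sink to 1) and the {vertex_index: value} shape of boundary_conditions are assumptions for illustration; the diff does not show the values the pipeline actually uses:

import numpy as np

# Assumes solve_laplace_beltrami_open_mesh (defined above) accepts
# boundary_conditions as a {vertex_index: value} mapping.
vertices = np.array(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
)
faces = np.array([[0, 1, 2], [1, 3, 2]])

src_indices, sink_indices = [0], [3]
boundary_conditions = {i: 0.0 for i in src_indices}
boundary_conditions.update({i: 1.0 for i in sink_indices})

coords = solve_laplace_beltrami_open_mesh(vertices, faces, boundary_conditions)
print(coords)  # per-vertex harmonic coordinate interpolating from 0 to 1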
61 changes: 61 additions & 0 deletions hippunfold/workflow/scripts/postproc_boundary_vertices.py
@@ -0,0 +1,61 @@
+import numpy as np
+import nibabel as nib
+import nibabel.gifti as gifti
+from lib.utils import setup_logger
+
+log_file = snakemake.log[0] if snakemake.log else None
+logger = setup_logger(log_file)
+
+nmin = snakemake.params.min_terminal_vertices
+
+logger.info("Loading src/sink signed distance maps from GIFTI...")
+edges = nib.load(snakemake.input.edges).agg_data()
+ap_src = nib.load(snakemake.input.ap_src).agg_data()
+ap_sink = nib.load(snakemake.input.ap_sink).agg_data()
+pd_src = nib.load(snakemake.input.pd_src).agg_data()
+pd_sink = nib.load(snakemake.input.pd_sink).agg_data()
+
+# One row per boundary vertex, one column per candidate label
+# (0=apsrc, 1=apsink, 2=pdsrc, 3=pdsink)
+distances = np.vstack(
+    (ap_src[edges == 1], ap_sink[edges == 1], pd_src[edges == 1], pd_sink[edges == 1])
+).T
+num_vertices, num_labels = distances.shape
+scaling_factors = np.ones(num_labels)
+
+logger.info("Assigning labels apsrc, apsink, pdsrc, pdsink")
+
+max_iterations = 10  # Prevent infinite loops
+for _ in range(max_iterations):
+    # Scale distances
+    scaled_distances = distances * scaling_factors
+
+    # Assign each boundary vertex to the label with the smallest scaled distance
+    labels = np.argmin(scaled_distances, axis=1)
+
+    # Count occurrences per label
+    unique, counts = np.unique(labels, return_counts=True)
+    label_counts = {k: dict(zip(unique, counts)).get(k, 0) for k in range(num_labels)}
+
+    # Stop once all labels are sufficiently represented
+    if all(count >= nmin for count in label_counts.values()):
+        break
+
+    # Shrink the scaled distances of underrepresented labels so they can
+    # win more vertices on the next iteration
+    for k in range(num_labels):
+        if label_counts[k] < nmin:
+            scaling_factors[k] /= 1.5
+
+logger.info(f"Final label counts: {label_counts}")
+
+# Write each mask back onto the full set of vertices (1=src, 2=sink)
+boundary_indices = np.where(edges == 1)[0]
+ap_srcsink = np.zeros(len(edges), dtype=np.int32)
+ap_srcsink[boundary_indices[labels == 0]] = 1
+ap_srcsink[boundary_indices[labels == 1]] = 2
+gii_img = gifti.GiftiImage(darrays=[gifti.GiftiDataArray(ap_srcsink)])
+nib.save(gii_img, snakemake.output.ap)
+
+pd_srcsink = np.zeros(len(edges), dtype=np.int32)
+pd_srcsink[boundary_indices[labels == 2]] = 1
+pd_srcsink[boundary_indices[labels == 3]] = 2
+gii_img = gifti.GiftiImage(darrays=[gifti.GiftiDataArray(pd_srcsink)])
+nib.save(gii_img, snakemake.output.pd)
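A quick toy check of the rebalancing loop (values hypothetical): label 2 starts with zero vertices, and shrinking its scaling factor lets it reclaim some:

import numpy as np

distances = np.array(
    [[0.1, 0.5, 0.6], [0.2, 0.4, 0.7], [0.5, 0.1, 0.6], [0.6, 0.2, 0.7]]
)
num_labels = distances.shape[1]
scaling_factors = np.ones(num_labels)
nmin = 1

for _ in range(10):
    labels = np.argmin(distances * scaling_factors, axis=1)
    counts = np.bincount(labels, minlength=num_labels)
    if counts.min() >= nmin:
        break
    scaling_factors[counts < nmin] /= 1.5  # same update rule as the script

print(labels)           # -> [0 2 1 2]: label 2 now owns vertices
print(scaling_factors)  # label 2's factor shrank until it became competitive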
