Skip to content

Commit

Permalink
refactoring of the laplace beltrami rule (#371)
Browse files Browse the repository at this point in the history
- splits into different rules for extracting the boundary, and getting
distances to src/sink masks
- those new rules now produce surface maps (for easier introspection),
which are passed along to the laplace beltrami rule
- the dentate AP src in bbhist hemi-L was problematic with existing
approach, since the PD src/sink were always closer than the AP src, so
no vertices were being labelled with AP src.
- in this version, we do things a little different:
  1. AP and PD are done independently now
  2. Instead of nearest neighbor, we use a distance threshold, which is
defined based on a signed distance transform of the mask, mapped to the
surface.
  3. Actually there is a minimum distance, and maximum distance, along
with the minimum number of vertices required in the src/sink. The
threshold is swept until the desired number is reached, to deal with
cases such as the above.

- TODO: the get_boundary_vertices script could be further improved to
perform connected components and picking the largest one, to avoid being
affected by holes or other defects..

- TODO: the laplace beltrami solver doesn't take as long to run now that
I am decimating alot, but if we want to further optimize things, I think
the step of altering the laplacian based on the boundary conditions, after it has been made a sparse matrix, is where the inefficiency lies (based on profiling). Could try to avoid setting those weights in the first place instead of setting it after the fact..

- NOTE: solve_laplace_beltrami_open_mesh() is kept as is (except for some
logging statements)
  • Loading branch information
akhanf authored Feb 10, 2025
1 parent ec4d644 commit 64546f9
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 108 deletions.
45 changes: 23 additions & 22 deletions hippunfold/workflow/rules/coords.smk
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,13 @@ def get_gm_labels(wildcards):
return lbl_list


def get_sink_labels(wildcards):
def get_src_sink_labels(wildcards):
lbl_list = " ".join(
[
str(lbl)
for lbl in config["laplace_labels"][wildcards.label][wildcards.dir]["sink"]
]
)
return lbl_list


def get_src_labels(wildcards):
lbl_list = " ".join(
[
str(lbl)
for lbl in config["laplace_labels"][wildcards.label][wildcards.dir]["src"]
for lbl in config["laplace_labels"][wildcards.label][wildcards.dir][
wildcards.srcsink
]
]
)
return lbl_list
Expand Down Expand Up @@ -100,19 +92,19 @@ def get_inputs_laplace(wildcards):
return files


rule get_sink_mask:
rule get_src_sink_mask:
input:
labelmap=get_labels_for_laplace,
params:
labels=get_sink_labels,
labels=get_src_sink_labels,
output:
mask=bids(
root=work,
datatype="coords",
suffix="mask.nii.gz",
space="corobl",
dir="{dir}",
desc="sink",
desc="{srcsink,src|sink}",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
Expand All @@ -127,19 +119,28 @@ rule get_sink_mask:
"c3d {input} -background -1 -retain-labels {params} -binarize {output}"


rule get_src_mask:
rule get_src_sink_sdt:
"""calculate signed distance transform (negative inside, positive outside)"""
input:
labelmap=get_labels_for_laplace,
params:
labels=get_src_labels,
output:
mask=bids(
root=work,
datatype="coords",
suffix="mask.nii.gz",
space="corobl",
dir="{dir}",
desc="src",
desc="{srcsink}",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
),
output:
sdt=bids(
root=work,
datatype="coords",
suffix="sdt.nii.gz",
space="corobl",
dir="{dir}",
desc="{srcsink}",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
Expand All @@ -151,7 +152,7 @@ rule get_src_mask:
group:
"subj"
shell:
"c3d {input} -background -1 -retain-labels {params} -binarize {output}"
"c3d {input} -sdt -o {output}"


rule get_nan_mask:
Expand Down
119 changes: 108 additions & 11 deletions hippunfold/workflow/rules/native_surf.smk
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ rule gen_native_mesh:
params:
threshold=lambda wildcards: surf_thresholds[wildcards.surfname],
decimate_opts={
"reduction": 0.5,
"reduction": 0.9,
"feature_angle": 25,
"preserve_topology": True,
},
Expand Down Expand Up @@ -173,7 +173,7 @@ rule smooth_surface:
# --- creating unfold surface from native anatomical, including post-processing


rule laplace_beltrami:
rule get_boundary_vertices:
input:
surf_gii=bids(
root=root,
Expand All @@ -184,25 +184,122 @@ rule laplace_beltrami:
label="{label}",
**inputs.subj_wildcards,
),
seg=get_labels_for_laplace,
params:
srcsink_labels=lambda wildcards: config["laplace_labels"][wildcards.label],
output:
coords_AP=bids(
label_gii=bids(
root=work,
datatype="surf",
suffix="boundary.label.gii",
space="corobl",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
),
group:
"subj"
container:
config["singularity"]["autotop"]
conda:
"../envs/pyvista.yaml"
script:
"../scripts/get_boundary_vertices.py"


rule map_src_sink_sdt_to_surf:
""" Maps the distance to src/sink mask """
input:
surf_gii=bids(
root=root,
datatype="surf",
suffix="midthickness.surf.gii",
space="corobl",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
),
sdt=bids(
root=work,
datatype="coords",
dir="AP",
suffix="coords.shape.gii",
desc="laplace",
suffix="sdt.nii.gz",
space="corobl",
dir="{dir}",
desc="{srcsink}",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
),
coords_PD=bids(
output:
sdt=bids(
root=work,
datatype="surf",
suffix="sdt.shape.gii",
space="corobl",
hemi="{hemi}",
dir="{dir}",
desc="{srcsink}",
label="{label}",
**inputs.subj_wildcards,
),
container:
config["singularity"]["autotop"]
conda:
"../envs/workbench.yaml"
group:
"subj"
shell:
"wb_command -volume-to-surface-mapping {input.sdt} {input.surf_gii} {output.sdt} -trilinear"


rule laplace_beltrami:
input:
surf_gii=bids(
root=root,
datatype="surf",
suffix="midthickness.surf.gii",
space="corobl",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
),
src_sdt=bids(
root=work,
datatype="surf",
suffix="sdt.shape.gii",
space="corobl",
hemi="{hemi}",
dir="{dir}",
desc="src",
label="{label}",
**inputs.subj_wildcards,
),
sink_sdt=bids(
root=work,
datatype="surf",
suffix="sdt.shape.gii",
space="corobl",
hemi="{hemi}",
dir="{dir}",
desc="sink",
label="{label}",
**inputs.subj_wildcards,
),
boundary=bids(
root=work,
datatype="surf",
suffix="boundary.label.gii",
space="corobl",
hemi="{hemi}",
label="{label}",
**inputs.subj_wildcards,
),
params:
min_dist_threshold=0.3,
max_dist_threshold=1,
min_terminal_vertices=5,
output:
coords=bids(
root=work,
datatype="coords",
dir="PD",
dir="{dir}",
suffix="coords.shape.gii",
desc="laplace",
space="corobl",
Expand Down
97 changes: 97 additions & 0 deletions hippunfold/workflow/scripts/get_boundary_vertices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pyvista as pv
import numpy as np
import nibabel as nib
import nibabel.gifti as gifti
from collections import defaultdict
from lib.utils import setup_logger

# Setup logger
log_file = snakemake.log[0] if snakemake.log else None
logger = setup_logger(log_file)


def find_boundary_vertices(mesh):
"""
Find boundary vertices of a 3D mesh.
Args:
mesh
Returns:
list: List of vertex indices that are boundary vertices, sorted in ascending order.
"""
vertices = mesh.points
faces = mesh.faces.reshape((-1, 4))[:, 1:4] # Extract triangle indices

edge_count = defaultdict(int)
# Step 1: Count edge occurrences
for face in faces:
# Extract edges from the face, ensure consistent ordering (min, max)
edges = [
tuple(sorted((face[0], face[1]))),
tuple(sorted((face[1], face[2]))),
tuple(sorted((face[2], face[0]))),
]
for edge in edges:
edge_count[edge] += 1
# Step 2: Identify boundary edges
boundary_edges = [edge for edge, count in edge_count.items() if count == 1]
# Step 3: Collect boundary vertices
boundary_vertices = set()
for edge in boundary_edges:
boundary_vertices.update(edge)
# Convert the set to a sorted list (array)
return np.array(sorted(boundary_vertices), dtype=np.int32)


def read_surface_from_gifti(surf_gii):
"""Load a surface mesh from a GIFTI file."""
surf = nib.load(surf_gii)
vertices = surf.agg_data("NIFTI_INTENT_POINTSET")
faces = surf.agg_data("NIFTI_INTENT_TRIANGLE")
faces_pv = np.hstack([np.full((faces.shape[0], 1), 3), faces]) # PyVista format

return pv.PolyData(vertices, faces_pv)


logger.info("Loading surface from GIFTI...")
surface = read_surface_from_gifti(snakemake.input.surf_gii)
logger.info(f"Surface loaded: {surface.n_points} vertices, {surface.n_faces} faces.")


logger.info("Find boundary vertices")
boundary_indices = find_boundary_vertices(surface)

boundary_scalars = np.zeros(surface.n_points, dtype=np.int32) # Default is 0
boundary_scalars[boundary_indices] = 1 # Set boundary vertices to 1
logger.info(
f"Boundary scalar array created. {np.sum(boundary_scalars)} boundary vertices marked."
)

logger.info("Saving GIFTI label file...")

# Create a GIFTI label data array
gii_data = gifti.GiftiDataArray(boundary_scalars, intent="NIFTI_INTENT_LABEL")

# Create a Label Table (LUT)
label_table = gifti.GiftiLabelTable()

# Define Background label (key 0)
background_label = gifti.GiftiLabel(
key=0, red=1.0, green=1.0, blue=1.0, alpha=0.0
) # Transparent
background_label.label = "Background"
label_table.labels.append(background_label)

# Define Boundary label (key 1)
boundary_label = gifti.GiftiLabel(
key=1, red=1.0, green=0.0, blue=0.0, alpha=1.0
) # Red color
boundary_label.label = "Boundary"
label_table.labels.append(boundary_label)

# Assign label table to GIFTI image
gii_img = gifti.GiftiImage(darrays=[gii_data], labeltable=label_table)

# Save the label file
gii_img.to_filename(snakemake.output.label_gii)
logger.info(f"GIFTI label file saved as '{snakemake.output.label_gii}'.")
Loading

0 comments on commit 64546f9

Please sign in to comment.