Skip to content

Commit 64546f9

Browse files
authored
refactoring of the laplace beltrami rule (#371)
- splits into different rules for extracting the boundary, and getting distances to src/sink masks - those new rules now produce surface maps (for easier introspection), which are passed along to the laplace beltrami rule - the dentate AP src in bbhist hemi-L was problematic with existing approach, since the PD src/sink were always closer than the AP src, so no vertices were being labelled with AP src. - in this version, we do things a little different: 1. AP and PD are done independently now 2. Instead of nearest neighbor, we use a distance threshold, which is defined based on a signed distance transform of the mask, mapped to the surface. 3. Actually there is a minimum distance, and maximum distance, along with the minimum number of vertices required in the src/sink. The threshold is swept until the desired number is reached, to deal with cases such as the above. - TODO: the get_boundary_vertices script could be further improved to perform connected components and picking the largest one, to avoid being affected by holes or other defects.. - TODO: the laplace beltrami solver doesn't take as long to run now that I am decimating alot, but if we want to further optimize things, I think the step of altering the laplacian based on the boundary conditions, after it has been made a sparse matrix, is where the inefficiency lies (based on profiling). Could try to avoid setting those weights in the first place instead of setting it after the fact.. - NOTE: solve_laplace_beltrami_open_mesh() is kept as is (except for some logging statements)
1 parent ec4d644 commit 64546f9

File tree

4 files changed

+285
-108
lines changed

4 files changed

+285
-108
lines changed

hippunfold/workflow/rules/coords.smk

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,13 @@ def get_gm_labels(wildcards):
2121
return lbl_list
2222

2323

24-
def get_sink_labels(wildcards):
24+
def get_src_sink_labels(wildcards):
2525
lbl_list = " ".join(
2626
[
2727
str(lbl)
28-
for lbl in config["laplace_labels"][wildcards.label][wildcards.dir]["sink"]
29-
]
30-
)
31-
return lbl_list
32-
33-
34-
def get_src_labels(wildcards):
35-
lbl_list = " ".join(
36-
[
37-
str(lbl)
38-
for lbl in config["laplace_labels"][wildcards.label][wildcards.dir]["src"]
28+
for lbl in config["laplace_labels"][wildcards.label][wildcards.dir][
29+
wildcards.srcsink
30+
]
3931
]
4032
)
4133
return lbl_list
@@ -100,19 +92,19 @@ def get_inputs_laplace(wildcards):
10092
return files
10193

10294

103-
rule get_sink_mask:
95+
rule get_src_sink_mask:
10496
input:
10597
labelmap=get_labels_for_laplace,
10698
params:
107-
labels=get_sink_labels,
99+
labels=get_src_sink_labels,
108100
output:
109101
mask=bids(
110102
root=work,
111103
datatype="coords",
112104
suffix="mask.nii.gz",
113105
space="corobl",
114106
dir="{dir}",
115-
desc="sink",
107+
desc="{srcsink,src|sink}",
116108
hemi="{hemi}",
117109
label="{label}",
118110
**inputs.subj_wildcards,
@@ -127,19 +119,28 @@ rule get_sink_mask:
127119
"c3d {input} -background -1 -retain-labels {params} -binarize {output}"
128120

129121

130-
rule get_src_mask:
122+
rule get_src_sink_sdt:
123+
"""calculate signed distance transform (negative inside, positive outside)"""
131124
input:
132-
labelmap=get_labels_for_laplace,
133-
params:
134-
labels=get_src_labels,
135-
output:
136125
mask=bids(
137126
root=work,
138127
datatype="coords",
139128
suffix="mask.nii.gz",
140129
space="corobl",
141130
dir="{dir}",
142-
desc="src",
131+
desc="{srcsink}",
132+
hemi="{hemi}",
133+
label="{label}",
134+
**inputs.subj_wildcards,
135+
),
136+
output:
137+
sdt=bids(
138+
root=work,
139+
datatype="coords",
140+
suffix="sdt.nii.gz",
141+
space="corobl",
142+
dir="{dir}",
143+
desc="{srcsink}",
143144
hemi="{hemi}",
144145
label="{label}",
145146
**inputs.subj_wildcards,
@@ -151,7 +152,7 @@ rule get_src_mask:
151152
group:
152153
"subj"
153154
shell:
154-
"c3d {input} -background -1 -retain-labels {params} -binarize {output}"
155+
"c3d {input} -sdt -o {output}"
155156

156157

157158
rule get_nan_mask:

hippunfold/workflow/rules/native_surf.smk

Lines changed: 108 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ rule gen_native_mesh:
6666
params:
6767
threshold=lambda wildcards: surf_thresholds[wildcards.surfname],
6868
decimate_opts={
69-
"reduction": 0.5,
69+
"reduction": 0.9,
7070
"feature_angle": 25,
7171
"preserve_topology": True,
7272
},
@@ -173,7 +173,7 @@ rule smooth_surface:
173173
# --- creating unfold surface from native anatomical, including post-processing
174174

175175

176-
rule laplace_beltrami:
176+
rule get_boundary_vertices:
177177
input:
178178
surf_gii=bids(
179179
root=root,
@@ -184,25 +184,122 @@ rule laplace_beltrami:
184184
label="{label}",
185185
**inputs.subj_wildcards,
186186
),
187-
seg=get_labels_for_laplace,
188-
params:
189-
srcsink_labels=lambda wildcards: config["laplace_labels"][wildcards.label],
190187
output:
191-
coords_AP=bids(
188+
label_gii=bids(
189+
root=work,
190+
datatype="surf",
191+
suffix="boundary.label.gii",
192+
space="corobl",
193+
hemi="{hemi}",
194+
label="{label}",
195+
**inputs.subj_wildcards,
196+
),
197+
group:
198+
"subj"
199+
container:
200+
config["singularity"]["autotop"]
201+
conda:
202+
"../envs/pyvista.yaml"
203+
script:
204+
"../scripts/get_boundary_vertices.py"
205+
206+
207+
rule map_src_sink_sdt_to_surf:
208+
""" Maps the distance to src/sink mask """
209+
input:
210+
surf_gii=bids(
211+
root=root,
212+
datatype="surf",
213+
suffix="midthickness.surf.gii",
214+
space="corobl",
215+
hemi="{hemi}",
216+
label="{label}",
217+
**inputs.subj_wildcards,
218+
),
219+
sdt=bids(
192220
root=work,
193221
datatype="coords",
194-
dir="AP",
195-
suffix="coords.shape.gii",
196-
desc="laplace",
222+
suffix="sdt.nii.gz",
197223
space="corobl",
224+
dir="{dir}",
225+
desc="{srcsink}",
198226
hemi="{hemi}",
199227
label="{label}",
200228
**inputs.subj_wildcards,
201229
),
202-
coords_PD=bids(
230+
output:
231+
sdt=bids(
232+
root=work,
233+
datatype="surf",
234+
suffix="sdt.shape.gii",
235+
space="corobl",
236+
hemi="{hemi}",
237+
dir="{dir}",
238+
desc="{srcsink}",
239+
label="{label}",
240+
**inputs.subj_wildcards,
241+
),
242+
container:
243+
config["singularity"]["autotop"]
244+
conda:
245+
"../envs/workbench.yaml"
246+
group:
247+
"subj"
248+
shell:
249+
"wb_command -volume-to-surface-mapping {input.sdt} {input.surf_gii} {output.sdt} -trilinear"
250+
251+
252+
rule laplace_beltrami:
253+
input:
254+
surf_gii=bids(
255+
root=root,
256+
datatype="surf",
257+
suffix="midthickness.surf.gii",
258+
space="corobl",
259+
hemi="{hemi}",
260+
label="{label}",
261+
**inputs.subj_wildcards,
262+
),
263+
src_sdt=bids(
264+
root=work,
265+
datatype="surf",
266+
suffix="sdt.shape.gii",
267+
space="corobl",
268+
hemi="{hemi}",
269+
dir="{dir}",
270+
desc="src",
271+
label="{label}",
272+
**inputs.subj_wildcards,
273+
),
274+
sink_sdt=bids(
275+
root=work,
276+
datatype="surf",
277+
suffix="sdt.shape.gii",
278+
space="corobl",
279+
hemi="{hemi}",
280+
dir="{dir}",
281+
desc="sink",
282+
label="{label}",
283+
**inputs.subj_wildcards,
284+
),
285+
boundary=bids(
286+
root=work,
287+
datatype="surf",
288+
suffix="boundary.label.gii",
289+
space="corobl",
290+
hemi="{hemi}",
291+
label="{label}",
292+
**inputs.subj_wildcards,
293+
),
294+
params:
295+
min_dist_threshold=0.3,
296+
max_dist_threshold=1,
297+
min_terminal_vertices=5,
298+
output:
299+
coords=bids(
203300
root=work,
204301
datatype="coords",
205-
dir="PD",
302+
dir="{dir}",
206303
suffix="coords.shape.gii",
207304
desc="laplace",
208305
space="corobl",
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pyvista as pv
2+
import numpy as np
3+
import nibabel as nib
4+
import nibabel.gifti as gifti
5+
from collections import defaultdict
6+
from lib.utils import setup_logger
7+
8+
# Setup logger
9+
log_file = snakemake.log[0] if snakemake.log else None
10+
logger = setup_logger(log_file)
11+
12+
13+
def find_boundary_vertices(mesh):
14+
"""
15+
Find boundary vertices of a 3D mesh.
16+
17+
Args:
18+
mesh
19+
Returns:
20+
list: List of vertex indices that are boundary vertices, sorted in ascending order.
21+
"""
22+
vertices = mesh.points
23+
faces = mesh.faces.reshape((-1, 4))[:, 1:4] # Extract triangle indices
24+
25+
edge_count = defaultdict(int)
26+
# Step 1: Count edge occurrences
27+
for face in faces:
28+
# Extract edges from the face, ensure consistent ordering (min, max)
29+
edges = [
30+
tuple(sorted((face[0], face[1]))),
31+
tuple(sorted((face[1], face[2]))),
32+
tuple(sorted((face[2], face[0]))),
33+
]
34+
for edge in edges:
35+
edge_count[edge] += 1
36+
# Step 2: Identify boundary edges
37+
boundary_edges = [edge for edge, count in edge_count.items() if count == 1]
38+
# Step 3: Collect boundary vertices
39+
boundary_vertices = set()
40+
for edge in boundary_edges:
41+
boundary_vertices.update(edge)
42+
# Convert the set to a sorted list (array)
43+
return np.array(sorted(boundary_vertices), dtype=np.int32)
44+
45+
46+
def read_surface_from_gifti(surf_gii):
47+
"""Load a surface mesh from a GIFTI file."""
48+
surf = nib.load(surf_gii)
49+
vertices = surf.agg_data("NIFTI_INTENT_POINTSET")
50+
faces = surf.agg_data("NIFTI_INTENT_TRIANGLE")
51+
faces_pv = np.hstack([np.full((faces.shape[0], 1), 3), faces]) # PyVista format
52+
53+
return pv.PolyData(vertices, faces_pv)
54+
55+
56+
logger.info("Loading surface from GIFTI...")
57+
surface = read_surface_from_gifti(snakemake.input.surf_gii)
58+
logger.info(f"Surface loaded: {surface.n_points} vertices, {surface.n_faces} faces.")
59+
60+
61+
logger.info("Find boundary vertices")
62+
boundary_indices = find_boundary_vertices(surface)
63+
64+
boundary_scalars = np.zeros(surface.n_points, dtype=np.int32) # Default is 0
65+
boundary_scalars[boundary_indices] = 1 # Set boundary vertices to 1
66+
logger.info(
67+
f"Boundary scalar array created. {np.sum(boundary_scalars)} boundary vertices marked."
68+
)
69+
70+
logger.info("Saving GIFTI label file...")
71+
72+
# Create a GIFTI label data array
73+
gii_data = gifti.GiftiDataArray(boundary_scalars, intent="NIFTI_INTENT_LABEL")
74+
75+
# Create a Label Table (LUT)
76+
label_table = gifti.GiftiLabelTable()
77+
78+
# Define Background label (key 0)
79+
background_label = gifti.GiftiLabel(
80+
key=0, red=1.0, green=1.0, blue=1.0, alpha=0.0
81+
) # Transparent
82+
background_label.label = "Background"
83+
label_table.labels.append(background_label)
84+
85+
# Define Boundary label (key 1)
86+
boundary_label = gifti.GiftiLabel(
87+
key=1, red=1.0, green=0.0, blue=0.0, alpha=1.0
88+
) # Red color
89+
boundary_label.label = "Boundary"
90+
label_table.labels.append(boundary_label)
91+
92+
# Assign label table to GIFTI image
93+
gii_img = gifti.GiftiImage(darrays=[gii_data], labeltable=label_table)
94+
95+
# Save the label file
96+
gii_img.to_filename(snakemake.output.label_gii)
97+
logger.info(f"GIFTI label file saved as '{snakemake.output.label_gii}'.")

0 commit comments

Comments
 (0)