From e62ce251d2e5a71485af539c1dda9768f655a0b6 Mon Sep 17 00:00:00 2001
From: Jordan DeKraker
Date: Tue, 18 Feb 2025 12:26:34 -0500
Subject: [PATCH] orthogonalize AD/PD src/sinks

---
Review note (ignored by `git am`): relative to the original commit,
get_boundary_vertices.py now defines `largest_component` before using it,
laplace_beltrami.py keeps its logger setup (it still calls `logger.info`), and
postproc_boundary_vertices.py was corrected (see the comments in that file).
Hunk counts and the diffstat below were updated to match; the blob ids on the
`index` lines are those of the original commit.

 hippunfold/workflow/rules/native_surf.smk     |  92 +++++++++--
 .../workflow/scripts/get_boundary_vertices.py |  10 ++
 .../workflow/scripts/laplace_beltrami.py      | 150 +----------------
 .../scripts/postproc_boundary_vertices.py     |  88 ++++++++++
 4 files changed, 179 insertions(+), 161 deletions(-)
 create mode 100644 hippunfold/workflow/scripts/postproc_boundary_vertices.py

diff --git a/hippunfold/workflow/rules/native_surf.smk b/hippunfold/workflow/rules/native_surf.smk
index 6d9471ae..db5acafe 100644
--- a/hippunfold/workflow/rules/native_surf.smk
+++ b/hippunfold/workflow/rules/native_surf.smk
@@ -250,40 +250,54 @@ rule map_src_sink_sdt_to_surf:
         "wb_command -volume-to-surface-mapping {input.sdt} {input.surf_gii} {output.sdt} -trilinear"
 
 
-rule laplace_beltrami:
+rule joint_smooth_ap_pd_edges:
+    """ ensures non-overlapping and full labelling of AP/PD edges """
     input:
-        surf_gii=bids(
-            root=root,
+        ap_src=bids(
+            root=work,
             datatype="surf",
-            suffix="midthickness.surf.gii",
+            suffix="sdt.shape.gii",
             space="corobl",
             hemi="{hemi}",
+            dir="AP",
+            desc="src",
             label="{label}",
             **inputs.subj_wildcards,
         ),
-        src_sdt=bids(
+        ap_sink=bids(
             root=work,
             datatype="surf",
             suffix="sdt.shape.gii",
             space="corobl",
             hemi="{hemi}",
-            dir="{dir}",
+            dir="AP",
+            desc="sink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+        pd_src=bids(
+            root=work,
+            datatype="surf",
+            suffix="sdt.shape.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="PD",
             desc="src",
             label="{label}",
             **inputs.subj_wildcards,
         ),
-        sink_sdt=bids(
+        pd_sink=bids(
             root=work,
             datatype="surf",
             suffix="sdt.shape.gii",
             space="corobl",
             hemi="{hemi}",
-            dir="{dir}",
+            dir="PD",
             desc="sink",
             label="{label}",
             **inputs.subj_wildcards,
         ),
-        boundary=bids(
+        edges=bids(
             root=work,
             datatype="surf",
             suffix="boundary.label.gii",
@@ -293,11 +307,61 @@ rule laplace_beltrami:
             **inputs.subj_wildcards,
         ),
     params:
-        min_dist_percentile=1,
-        max_dist_percentile=10,
-        min_terminal_vertices=lambda wildcards: 5 if wildcards.dir == "AP" else 100,  #TODO, instead of # of vertices, we should compute the total length of the segment
-        threshold_method=lambda wildcards: (
-            "percentile" if wildcards.dir == "AP" else "firstminima"
+        min_terminal_vertices=5,  # min number of vertices per src/sink
+    output:
+        ap=bids(
+            root=work,
+            datatype="surf",
+            suffix="mask.label.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="AP",
+            desc="srcsink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+        pd=bids(
+            root=work,
+            datatype="surf",
+            suffix="mask.label.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="PD",
+            desc="srcsink",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+    container:
+        config["singularity"]["autotop"]
+    conda:
+        "../envs/pyvista.yaml"
+    group:
+        "subj"
+    script:
+        "../scripts/postproc_boundary_vertices.py"
+
+
+rule laplace_beltrami:
+    input:
+        surf_gii=bids(
+            root=root,
+            datatype="surf",
+            suffix="midthickness.surf.gii",
+            space="corobl",
+            hemi="{hemi}",
+            label="{label}",
+            **inputs.subj_wildcards,
+        ),
+        src_sink_mask=bids(
+            root=work,
+            datatype="surf",
+            suffix="mask.label.gii",
+            space="corobl",
+            hemi="{hemi}",
+            dir="{dir}",
+            desc="srcsink",
+            label="{label}",
+            **inputs.subj_wildcards,
         ),
     output:
         coords=bids(
diff --git a/hippunfold/workflow/scripts/get_boundary_vertices.py b/hippunfold/workflow/scripts/get_boundary_vertices.py
index 1e5101d3..e7654ee3 100644
--- a/hippunfold/workflow/scripts/get_boundary_vertices.py
+++ b/hippunfold/workflow/scripts/get_boundary_vertices.py
@@ -67,6 +67,16 @@ def read_surface_from_gifti(surf_gii):
     f"Boundary scalar array created. {np.sum(boundary_scalars)} boundary vertices marked."
 )
 
+# Keep only the largest connected component of the boundary sub-mesh
+sub_mesh = pv.PolyData(surface.points, surface.faces).extract_points(
+    boundary_scalars.astype(bool), adjacent_cells=True
+)
+largest_component = sub_mesh.connectivity(largest=True)
+boundary_scalars = np.zeros(surface.n_points, dtype=int)
+boundary_scalars[largest_component.point_data["vtkOriginalPointIds"]] = 1
+logger.info("Applying largest connected components")
+
+
 logger.info("Saving GIFTI label file...")
 
 # Create a GIFTI label data array
diff --git a/hippunfold/workflow/scripts/laplace_beltrami.py b/hippunfold/workflow/scripts/laplace_beltrami.py
index ed953c3b..b1b2b8b8 100644
--- a/hippunfold/workflow/scripts/laplace_beltrami.py
+++ b/hippunfold/workflow/scripts/laplace_beltrami.py
@@ -3,124 +3,9 @@ from scipy.sparse import diags, linalg, lil_matrix
 from lib.utils import setup_logger
 
 
 log_file = snakemake.log[0] if snakemake.log else None
 logger = setup_logger(log_file)
-
-import numpy as np
-
-import numpy as np
-from scipy.signal import argrelextrema
-
-
-def get_terminal_indices_firstminima(
-    sdt, min_vertices, boundary_mask, bins=100, smoothing_window=5
-):
-    """
-    Gets the terminal (src/sink) vertex indices by determining an adaptive threshold
-    using the first local minimum of the histogram of `sdt` values.
-
-    Parameters:
-    - sdt: Signed distance transform array.
-    - min_vertices: The minimum number of vertices required.
-    - boundary_mask: Boolean or binary mask indicating boundary regions.
-    - bins: Number of bins to use in the histogram (default: 100).
-    - smoothing_window: Window size for moving average smoothing (default: 5).
-
-    Returns:
-    - indices: List of terminal vertex indices.
-
-    Raises:
-    - ValueError: If the minimum number of vertices cannot be found.
-    """
-
-    # Extract SDT values within the boundary mask
-    sdt_values = sdt[boundary_mask == 1]
-
-    # Compute histogram
-    hist, bin_edges = np.histogram(sdt_values, bins=bins, density=True)
-    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
-
-    # Smooth the histogram using a simple moving average
-    smoothed_hist = np.convolve(
-        hist, np.ones(smoothing_window) / smoothing_window, mode="same"
-    )
-
-    # Find local minima
-    minima_indices = argrelextrema(smoothed_hist, np.less)[0]
-
-    if len(minima_indices) == 0:
-        raise ValueError("No local minima found in the histogram.")
-
-    # Select the first local minimum after the first peak
-    first_minimum_bin = bin_centers[minima_indices[0]]
-
-    # Select indices where SDT is below this threshold
-    indices = np.where((sdt < first_minimum_bin) & (boundary_mask == 1))[0].tolist()
-
-    if len(indices) >= min_vertices:
-        return indices
-
-    raise ValueError(
-        f"Unable to find minimum of {min_vertices} vertices on boundary within the first local minimum threshold."
-    )
-
-
-def get_terminal_indices_percentile(
-    sdt, min_percentile, max_percentile, min_vertices, boundary_mask
-):
-    """
-    Gets the terminal (src/sink) vertex indices by sweeping a percentile-based threshold
-    of the signed distance transform (sdt), ensuring at least `min_vertices` are selected.
-
-    Instead of a fixed distance range, this function dynamically determines the threshold
-    by scanning from `min_percentile` to `max_percentile`.
-
-    Parameters:
-    - sdt: Signed distance transform array.
-    - min_percentile: Starting percentile for thresholding (0-100).
-    - max_percentile: Maximum percentile for thresholding (0-100).
-    - min_vertices: The minimum number of vertices required.
-    - boundary_mask: Boolean or binary mask indicating boundary regions.
-
-    Returns:
-    - indices: List of terminal vertex indices.
-
-    Raises:
-    - ValueError: If the minimum number of vertices cannot be found.
-    """
-
-    for percentile in np.arange(min_percentile, max_percentile, 0.5):
-        dist_threshold = np.percentile(sdt[boundary_mask == 1], percentile)
-        indices = np.where((sdt < dist_threshold) & (boundary_mask == 1))[0].tolist()
-
-        if len(indices) >= min_vertices:
-            logger.info(
-                f"Using {percentile}-th percentile to obtain sdt threshold of {dist_threshold}, with {len(indices)} vertices"
-            )
-            return indices
-
-    raise ValueError(
-        f"Unable to find minimum of {min_vertices} vertices on boundary within the {max_percentile}th percentile of distances"
-    )
-
-
-def get_terminal_indices_threshold(
-    sdt, min_dist, max_dist, min_vertices, boundary_mask
-):
-    """
-    Gets the terminal (src/sink) vertex indices based on distance to the src/sink mask,
-    a boundary mask, and a minumum number of vertices. The distance from the mask is
-    swept from min_dist to max_dist, until the min_vertices is achieved, else an
-    exception is thrown."""
-
-    for dist in np.linspace(min_dist, max_dist, 20):
-        indices = np.where((sdt < dist) & (boundary_mask == 1))[0].tolist()
-        if len(indices) >= min_vertices:
-            return indices
-    raise ValueError(
-        f"Unable to find minimum of {min_vertices} vertices on boundary, within {max_dist}mm of the terminal mask"
-    )
 
 
 def solve_laplace_beltrami_open_mesh(vertices, faces, boundary_conditions=None):
     """
@@ -211,38 +96,9 @@ def solve_laplace_beltrami_open_mesh(vertices, faces, boundary_conditions=None):
 vertices = surf.agg_data("NIFTI_INTENT_POINTSET")
 faces = surf.agg_data("NIFTI_INTENT_TRIANGLE")
 
-boundary_mask = nib.load(snakemake.input.boundary).agg_data()
-src_sdt = nib.load(snakemake.input.src_sdt).agg_data()
-sink_sdt = nib.load(snakemake.input.sink_sdt).agg_data()
-
-if snakemake.params.threshold_method == "percentile":
-    src_indices = get_terminal_indices_percentile(
-        src_sdt,
-        snakemake.params.min_dist_percentile,
-        snakemake.params.max_dist_percentile,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
-    sink_indices = get_terminal_indices_percentile(
-        sink_sdt,
-        snakemake.params.min_dist_percentile,
-        snakemake.params.max_dist_percentile,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
-
-
-elif snakemake.params.threshold_method == "firstminima":
-    src_indices = get_terminal_indices_firstminima(
-        src_sdt,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
-    sink_indices = get_terminal_indices_firstminima(
-        sink_sdt,
-        snakemake.params.min_terminal_vertices,
-        boundary_mask,
-    )
+src_sink_mask = nib.load(snakemake.input.src_sink_mask).agg_data()
+src_indices = np.where(src_sink_mask == 1)[0]
+sink_indices = np.where(src_sink_mask == 2)[0]
 
 logger.info(f"# of src boundary vertices: {len(src_indices)}")
 
diff --git a/hippunfold/workflow/scripts/postproc_boundary_vertices.py b/hippunfold/workflow/scripts/postproc_boundary_vertices.py
new file mode 100644
index 00000000..b2a9cdf3
--- /dev/null
+++ b/hippunfold/workflow/scripts/postproc_boundary_vertices.py
@@ -0,0 +1,88 @@
+import pyvista as pv
+import numpy as np
+import nibabel as nib
+import nibabel.gifti as gifti
+from collections import Counter
+from scipy.signal import argrelextrema
+from lib.utils import setup_logger
+
+log_file = snakemake.log[0] if snakemake.log else None
+logger = setup_logger(log_file)
+
+# Minimum number of edge vertices each of the four src/sink labels must receive.
+nmin = snakemake.params.min_terminal_vertices
+
+logger.info("Loading surface from GIFTI...")
+edges = nib.load(snakemake.input.edges).agg_data()
+ap_src = nib.load(snakemake.input.ap_src).agg_data()
+ap_sink = nib.load(snakemake.input.ap_sink).agg_data()
+pd_src = nib.load(snakemake.input.pd_src).agg_data()
+pd_sink = nib.load(snakemake.input.pd_sink).agg_data()
+
+# Positions (in the full per-vertex arrays) of the vertices marked as edges.
+edge_indices = np.flatnonzero(edges == 1)
+
+# One row per edge vertex, one column per label:
+# 0 = AP src, 1 = AP sink, 2 = PD src, 3 = PD sink.
+distances = np.column_stack(
+    (
+        ap_src[edge_indices],
+        ap_sink[edge_indices],
+        pd_src[edge_indices],
+        pd_sink[edge_indices],
+    )
+)
+num_vertices, num_labels = distances.shape
+
+# Per-label multiplier applied to the distances before labels are assigned.
+scaling_factors = np.ones(num_labels)
+
+
+logger.info("Assigning labels apsrc, apsink, pdsrc, pdsink")
+
+max_iterations = 10  # Prevent infinite loops
+for _ in range(max_iterations):
+    # Scale distances
+    scaled_distances = distances * scaling_factors
+
+    # Assign labels based on min scaled distance
+    labels = np.argmin(scaled_distances, axis=1)
+
+    # Count occurrences per label; labels that were never picked count as 0
+    counts = np.bincount(labels, minlength=num_labels)
+    label_counts = {k: int(counts[k]) for k in range(num_labels)}
+
+    # Check if all labels meet nmin
+    if all(count >= nmin for count in label_counts.values()):
+        break  # Stop if all labels are sufficiently represented
+
+    # Update scaling factors for underrepresented labels
+    for k in range(num_labels):
+        if label_counts[k] < nmin:
+            # The smallest scaled distance wins, so the factor must shrink for
+            # this label to claim more vertices.
+            # NOTE(review): assumes non-negative distances -- confirm.
+            scaling_factors[k] /= 1.5
+else:
+    logger.warning(
+        f"Some labels still have fewer than {nmin} vertices after {max_iterations} iterations"
+    )
+
+logger.info(f"Final label counts over {num_vertices} edge vertices: {label_counts}")
+
+
+def save_srcsink_mask(src_label, sink_label, out_path):
+    """Write a 0/1/2 (none/src/sink) vertex mask to out_path as GIFTI."""
+    srcsink = np.zeros(len(edges), dtype=np.int32)
+    # Index through edge_indices so the assignment lands in srcsink itself;
+    # chaining two boolean masks would only modify a temporary copy.
+    srcsink[edge_indices[labels == src_label]] = 1
+    srcsink[edge_indices[labels == sink_label]] = 2
+    darray = gifti.GiftiDataArray(
+        data=srcsink, intent="NIFTI_INTENT_LABEL", datatype="NIFTI_TYPE_INT32"
+    )
+    nib.save(gifti.GiftiImage(darrays=[darray]), out_path)
+
+
+save_srcsink_mask(0, 1, snakemake.output.ap)
+save_srcsink_mask(2, 3, snakemake.output.pd)