Skip to content

Commit

Permalink
fix: make anatomical the reference in SyN registration
Browse files Browse the repository at this point in the history
Resolves: #85.
Context: nipreps/fmriprep#2530
  • Loading branch information
oesteban committed Oct 1, 2021
1 parent de05b84 commit 3d0da8e
Showing 1 changed file with 87 additions and 42 deletions.
129 changes: 87 additions & 42 deletions sdcflows/workflows/fit/syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def init_syn_sdc_wf(
from niworkflows.interfaces.header import CopyXForm
from niworkflows.interfaces.nibabel import Binarize, RegridToZooms
from ...utils.misc import front as _pop
from ...interfaces.utils import Reoblique
from ...interfaces.bspline import (
BSplineApprox,
DEFAULT_LF_ZOOMS_MM,
Expand Down Expand Up @@ -211,38 +210,35 @@ def init_syn_sdc_wf(

atlas_msk = pe.Node(Binarize(thresh_low=atlas_threshold), name="atlas_msk")
anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk")
epi_dilmsk = pe.Node(BinaryDilation(), name="epi_dilmsk")
amask2epi = pe.Node(
ApplyTransforms(interpolation="MultiLabel", transforms="identity"),
name="amask2epi",
)
prior2epi = pe.Node(
ApplyTransforms(interpolation="MultiLabel", transforms="identity"),
name="prior2epi",
)
prior_dilmsk = pe.Node(BinaryDilation(radius=4), name="prior_dilmsk")

epi_umask = pe.Node(Union(), name="epi_umask")
moving_masks = pe.Node(
niu.Merge(3),
name="moving_masks",
run_without_submitting=True,
)
moving_masks.inputs.in1 = "NULL"
moving_masks.inputs.in2 = "NULL"

fixed_masks = pe.Node(
niu.Merge(3),
name="fixed_masks",
mem_gb=DEFAULT_MEMORY_MIN_GB,
run_without_submitting=True,
)

deoblique = pe.Node(CopyXForm(fields=["epi_ref"]), name="deoblique")
reoblique = pe.Node(Reoblique(), name="reoblique")

# Set a manageable size for the epi reference
find_zooms = pe.Node(niu.Function(function=_adjust_zooms), name="find_zooms")
zooms_epi = pe.Node(RegridToZooms(), name="zooms_epi")

histmatch = pe.Node(niu.Function(function=match_histogram),
name="histmatch")

# SyN Registration Core
syn = pe.Node(
Registration(
Expand All @@ -258,7 +254,7 @@ def init_syn_sdc_wf(
syn.inputs.args = "--write-interval-volumes 5"

unwarp_ref = pe.Node(
ApplyTransforms(interpolation="BSpline"),
ApplyTransforms(interpolation="BSpline", args="-u float"),
name="unwarp_ref",
)

Expand All @@ -267,11 +263,11 @@ def init_syn_sdc_wf(

# Check zooms (avoid very expensive B-Splines fitting)
zooms_field = pe.Node(
ApplyTransforms(interpolation="BSpline", transforms="identity"),
ApplyTransforms(interpolation="BSpline", transforms="identity", args="-u float"),
name="zooms_field",
)
zooms_bmask = pe.Node(
ApplyTransforms(interpolation="MultiLabel", transforms="identity"),
ApplyTransforms(interpolation="MultiLabel", transforms="identity", args="-u uchar"),
name="zooms_bmask",
)

Expand All @@ -285,51 +281,46 @@ def init_syn_sdc_wf(

# fmt: off
workflow.connect([
(inputnode, extract_field, [("epi_ref", "epi_meta")]),
(inputnode, extract_field, [("epi_ref", "epi_meta"),
("anat_mask", "in_mask")]),
(inputnode, atlas_msk, [("sd_prior", "in_file")]),
(inputnode, deoblique, [(("epi_ref", _pop), "epi_ref"),
("epi_mask", "hdr_file")]),
(inputnode, reoblique, [(("epi_ref", _pop), "in_epi")]),
(inputnode, epi_dilmsk, [("epi_mask", "in_file")]),
(inputnode, amask2epi, [("epi_mask", "reference_image")]),
(inputnode, zooms_bmask, [("anat_mask", "input_image")]),
(inputnode, fixed_masks, [("anat_mask", "in2")]),
(inputnode, fixed_masks, [("anat_mask", "in1"),
("anat_mask", "in2")]),
(inputnode, anat_dilmsk, [("anat_mask", "in_file")]),
(inputnode, warp_dir, [("epi_ref", "intuple")]),
(inputnode, syn, [("anat_ref", "moving_image")]),
(epi_dilmsk, prior2epi, [("out_file", "reference_image")]),
(atlas_msk, prior2epi, [("out_file", "input_image")]),
(prior2epi, prior_dilmsk, [("output_image", "in_file")]),
(anat_dilmsk, fixed_masks, [("out_file", "in1")]),
(inputnode, histmatch, [("anat_ref", "reference")]),
(inputnode, syn, [("anat_ref", "fixed_image")]),
(inputnode, epi_umask, [("epi_mask", "in1")]),
# (zooms_bmask, moving_masks, [("output_image", "in3")]),
(anat_dilmsk, histmatch, [("out_file", "ref_mask")]),
(warp_dir, syn, [("out", "restrict_deformation")]),
(inputnode, find_zooms, [("anat_ref", "in_anat"),
(("epi_ref", _pop), "in_epi")]),
(deoblique, histmatch, [("epi_ref", "image")]),
(deoblique, zooms_epi, [("epi_ref", "in_file")]),
(deoblique, unwarp_ref, [("epi_ref", "input_image")]),
(find_zooms, zooms_epi, [("out", "zooms")]),
(zooms_epi, unwarp_ref, [("out_file", "reference_image")]),
(atlas_msk, fixed_masks, [("out_mask", "in3")]),
(fixed_masks, syn, [("out", "moving_image_masks")]),
(epi_dilmsk, epi_umask, [("out_file", "in1")]),
(epi_dilmsk, amask2epi, [("out_file", "reference_image")]),
(fixed_masks, syn, [("out", "fixed_image_masks")]),
(anat_dilmsk, amask2epi, [("out_file", "input_image")]),
(amask2epi, epi_umask, [("output_image", "in2")]),
(epi_umask, moving_masks, [("out_file", "in1")]),
(prior_dilmsk, moving_masks, [("out_file", "in2")]),
(prior2epi, moving_masks, [("output_image", "in3")]),
(moving_masks, syn, [("out", "fixed_image_masks")]),
(deoblique, syn, [("epi_ref", "fixed_image")]),
(syn, extract_field, [("reverse_transforms", "in_file")]),
(syn, unwarp_ref, [("reverse_transforms", "transforms")]),
(unwarp_ref, zooms_bmask, [("output_image", "reference_image")]),
(epi_umask, histmatch, [("out_file", "img_mask")]),
# (moving_masks, syn, [("out", "moving_image_masks")]),
(histmatch, syn, [("out", "moving_image")]),
(syn, extract_field, [("forward_transforms", "in_file")]),
(syn, unwarp_ref, [("forward_transforms", "transforms")]),
(unwarp_ref, zooms_field, [("output_image", "reference_image")]),
(extract_field, zooms_field, [("out", "input_image")]),
(unwarp_ref, reoblique, [("output_image", "in_plumb")]),
(zooms_field, reoblique, [("output_image", "in_field")]),
(zooms_bmask, reoblique, [("output_image", "in_mask")]),
(reoblique, bs_filter, [("out_field", "in_data"),
("out_mask", "in_mask")]),
(reoblique, outputnode, [("out_epi", "fmap_ref"),
("out_mask", "fmap_mask")]),
(zooms_field, zooms_bmask, [("output_image", "reference_image")]),
(zooms_field, bs_filter, [("output_image", "in_data")]),
(zooms_bmask, bs_filter, [("output_image", "in_mask")]),
(unwarp_ref, outputnode, [("output_image", "fmap_ref")]),
(zooms_bmask, outputnode, [("output_image", "fmap_mask")]),
(bs_filter, outputnode, [
("out_extrapolated" if not debug else "out_field", "fmap"),
("out_coeff", "fmap_coeff")]),
Expand Down Expand Up @@ -492,7 +483,6 @@ def init_syn_preprocessing_wf(
name="merge_output",
run_without_submitting=True,
)

mask_anat = pe.Node(ApplyMask(), name="mask_anat")
clip_anat = pe.Node(
IntensityClip(p_min=0.0, p_max=99.8, invert=t1w_inversion), name="clip_anat"
Expand All @@ -514,6 +504,13 @@ def init_syn_preprocessing_wf(
IntensityClip(p_min=0.0, p_max=100), name="clip_anat_final"
)

def _remove_first_mask(in_file):
if not isinstance(in_file, list):
in_file = [in_file]

in_file.insert(0, "NULL")
return in_file

anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk")
epi_dilmsk = pe.Node(BinaryDilation(), name="epi_dilmsk")

Expand All @@ -536,7 +533,8 @@ def init_syn_preprocessing_wf(
(ref_anat, epi2anat, [("output_image", "fixed_image")]),
(anat_dilmsk, epi2anat, [("out_file", "fixed_image_masks")]),
(deob_epi, epi2anat, [("out_file", "moving_image")]),
(epi_dilmsk, epi2anat, [("out_file", "moving_image_masks")]),
(epi_dilmsk, epi2anat, [
(("out_file", _remove_first_mask), "moving_image_masks")]),
(deob_epi, sampling_ref, [("out_file", "fixed_image")]),
(epi2anat, transform_list, [("forward_transforms", "in1")]),
(transform_list, prior2epi, [("out", "transforms")]),
Expand Down Expand Up @@ -615,7 +613,7 @@ def _warp_dir(intuple, nlevels=3):
return nlevels * [[int(pe == ax) for ax in "ijk"]]


def _extract_field(in_file, epi_meta):
def _extract_field(in_file, epi_meta, in_mask=None):
"""
Extract the nonzero component of the deformation field estimated by ANTs.
Expand Down Expand Up @@ -648,6 +646,15 @@ def _extract_field(in_file, epi_meta):
/ trt
* (-1.0 if epi_meta[1]["PhaseEncodingDirection"].endswith("-") else 1.0)
)

if in_mask is None:
mask = np.ones_like(data, dtype=bool)
else:
mask = np.asanyarray(nb.load(in_mask).dataobj, dtype=bool)

# De-mean the result
data -= np.median(data[mask])

out_file = Path(fname_presuffix(Path(in_file[0]).name, suffix="_fieldmap"))
nii = nb.Nifti1Image(data, fieldnii.affine, None)
nii.header.set_xyzt_units(fieldnii.header.get_xyzt_units()[0])
Expand Down Expand Up @@ -688,3 +695,41 @@ def _adjust_zooms(in_anat, in_epi, z_max=2.2, z_min=1.8):
z_max,
)
return tuple([zoom_iso] * 3)


def match_histogram(reference, image, ref_mask=None, img_mask=None):
"""Match the histogram of the T2-like anatomical with the EPI."""
import os
import numpy as np
import nibabel as nb
from nipype.utils.filemanip import fname_presuffix
from skimage.exposure import match_histograms

nii_img = nb.load(image)
img_data = np.asanyarray(nii_img.dataobj)
ref_data = np.asanyarray(nb.load(reference).dataobj)

ref_mask = (
np.ones_like(ref_data, dtype=bool) if ref_mask is None
else np.asanyarray(nb.load(ref_mask).dataobj) > 0
)

img_mask = (
np.ones_like(img_data, dtype=bool) if img_mask is None
else np.asanyarray(nb.load(img_mask).dataobj) > 0
)

out_file = fname_presuffix(
image, suffix="_matched", newpath=os.getcwd()
)
img_data[img_mask] = match_histograms(
img_data[img_mask],
ref_data[ref_mask],
)

nii_img.__class__(
img_data,
nii_img.affine,
nii_img.header,
).to_filename(out_file)
return out_file

0 comments on commit 3d0da8e

Please sign in to comment.