From 3d0da8eb40c30b31a2e7aea093566b5465c98ae5 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 30 Sep 2021 10:27:49 +0200 Subject: [PATCH] fix: make anatomical the reference in SyN registration Resolves: #85. Context: nipreps/fmriprep#2530 --- sdcflows/workflows/fit/syn.py | 129 +++++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/sdcflows/workflows/fit/syn.py b/sdcflows/workflows/fit/syn.py index 8e1d16118b..416f02d30a 100644 --- a/sdcflows/workflows/fit/syn.py +++ b/sdcflows/workflows/fit/syn.py @@ -168,7 +168,6 @@ def init_syn_sdc_wf( from niworkflows.interfaces.header import CopyXForm from niworkflows.interfaces.nibabel import Binarize, RegridToZooms from ...utils.misc import front as _pop - from ...interfaces.utils import Reoblique from ...interfaces.bspline import ( BSplineApprox, DEFAULT_LF_ZOOMS_MM, @@ -211,16 +210,10 @@ def init_syn_sdc_wf( atlas_msk = pe.Node(Binarize(thresh_low=atlas_threshold), name="atlas_msk") anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk") - epi_dilmsk = pe.Node(BinaryDilation(), name="epi_dilmsk") amask2epi = pe.Node( ApplyTransforms(interpolation="MultiLabel", transforms="identity"), name="amask2epi", ) - prior2epi = pe.Node( - ApplyTransforms(interpolation="MultiLabel", transforms="identity"), - name="prior2epi", - ) - prior_dilmsk = pe.Node(BinaryDilation(radius=4), name="prior_dilmsk") epi_umask = pe.Node(Union(), name="epi_umask") moving_masks = pe.Node( @@ -228,6 +221,8 @@ def init_syn_sdc_wf( name="moving_masks", run_without_submitting=True, ) + moving_masks.inputs.in1 = "NULL" + moving_masks.inputs.in2 = "NULL" fixed_masks = pe.Node( niu.Merge(3), @@ -235,14 +230,15 @@ def init_syn_sdc_wf( mem_gb=DEFAULT_MEMORY_MIN_GB, run_without_submitting=True, ) - deoblique = pe.Node(CopyXForm(fields=["epi_ref"]), name="deoblique") - reoblique = pe.Node(Reoblique(), name="reoblique") # Set a manageable size for the epi reference find_zooms = pe.Node(niu.Function(function=_adjust_zooms), name="find_zooms") zooms_epi = pe.Node(RegridToZooms(), name="zooms_epi") + histmatch = pe.Node(niu.Function(function=match_histogram), + name="histmatch") + # SyN Registration Core syn = pe.Node( Registration( @@ -258,7 +254,7 @@ def init_syn_sdc_wf( syn.inputs.args = "--write-interval-volumes 5" unwarp_ref = pe.Node( - ApplyTransforms(interpolation="BSpline"), + ApplyTransforms(interpolation="BSpline", args="-u float"), name="unwarp_ref", ) @@ -267,11 +263,11 @@ def init_syn_sdc_wf( # Check zooms (avoid very expensive B-Splines fitting) zooms_field = pe.Node( - ApplyTransforms(interpolation="BSpline", transforms="identity"), + ApplyTransforms(interpolation="BSpline", transforms="identity", args="-u float"), name="zooms_field", ) zooms_bmask = pe.Node( - ApplyTransforms(interpolation="MultiLabel", transforms="identity"), + ApplyTransforms(interpolation="MultiLabel", transforms="identity", args="-u uchar"), name="zooms_bmask", ) @@ -285,51 +281,46 @@ def init_syn_sdc_wf( # fmt: off workflow.connect([ - (inputnode, extract_field, [("epi_ref", "epi_meta")]), + (inputnode, extract_field, [("epi_ref", "epi_meta"), + ("anat_mask", "in_mask")]), (inputnode, atlas_msk, [("sd_prior", "in_file")]), (inputnode, deoblique, [(("epi_ref", _pop), "epi_ref"), ("epi_mask", "hdr_file")]), - (inputnode, reoblique, [(("epi_ref", _pop), "in_epi")]), - (inputnode, epi_dilmsk, [("epi_mask", "in_file")]), + (inputnode, amask2epi, [("epi_mask", "reference_image")]), (inputnode, zooms_bmask, [("anat_mask", "input_image")]), - (inputnode, fixed_masks, [("anat_mask", "in2")]), + (inputnode, fixed_masks, [("anat_mask", "in1"), + ("anat_mask", "in2")]), (inputnode, anat_dilmsk, [("anat_mask", "in_file")]), (inputnode, warp_dir, [("epi_ref", "intuple")]), - (inputnode, syn, [("anat_ref", "moving_image")]), - (epi_dilmsk, prior2epi, [("out_file", "reference_image")]), - (atlas_msk, prior2epi, [("out_file", "input_image")]), - (prior2epi, prior_dilmsk, [("output_image", "in_file")]), - (anat_dilmsk, fixed_masks, [("out_file", "in1")]), + (inputnode, histmatch, [("anat_ref", "reference")]), + (inputnode, syn, [("anat_ref", "fixed_image")]), + (inputnode, epi_umask, [("epi_mask", "in1")]), + # (zooms_bmask, moving_masks, [("output_image", "in3")]), + (anat_dilmsk, histmatch, [("out_file", "ref_mask")]), (warp_dir, syn, [("out", "restrict_deformation")]), (inputnode, find_zooms, [("anat_ref", "in_anat"), (("epi_ref", _pop), "in_epi")]), + (deoblique, histmatch, [("epi_ref", "image")]), (deoblique, zooms_epi, [("epi_ref", "in_file")]), (deoblique, unwarp_ref, [("epi_ref", "input_image")]), (find_zooms, zooms_epi, [("out", "zooms")]), (zooms_epi, unwarp_ref, [("out_file", "reference_image")]), (atlas_msk, fixed_masks, [("out_mask", "in3")]), - (fixed_masks, syn, [("out", "moving_image_masks")]), - (epi_dilmsk, epi_umask, [("out_file", "in1")]), - (epi_dilmsk, amask2epi, [("out_file", "reference_image")]), + (fixed_masks, syn, [("out", "fixed_image_masks")]), (anat_dilmsk, amask2epi, [("out_file", "input_image")]), (amask2epi, epi_umask, [("output_image", "in2")]), - (epi_umask, moving_masks, [("out_file", "in1")]), - (prior_dilmsk, moving_masks, [("out_file", "in2")]), - (prior2epi, moving_masks, [("output_image", "in3")]), - (moving_masks, syn, [("out", "fixed_image_masks")]), - (deoblique, syn, [("epi_ref", "fixed_image")]), - (syn, extract_field, [("reverse_transforms", "in_file")]), - (syn, unwarp_ref, [("reverse_transforms", "transforms")]), - (unwarp_ref, zooms_bmask, [("output_image", "reference_image")]), + (epi_umask, histmatch, [("out_file", "img_mask")]), + # (moving_masks, syn, [("out", "moving_image_masks")]), + (histmatch, syn, [("out", "moving_image")]), + (syn, extract_field, [("forward_transforms", "in_file")]), + (syn, unwarp_ref, [("forward_transforms", "transforms")]), (unwarp_ref, zooms_field, [("output_image", "reference_image")]), (extract_field, zooms_field, [("out", "input_image")]), - (unwarp_ref, reoblique, [("output_image", "in_plumb")]), - (zooms_field, reoblique, [("output_image", "in_field")]), - (zooms_bmask, reoblique, [("output_image", "in_mask")]), - (reoblique, bs_filter, [("out_field", "in_data"), - ("out_mask", "in_mask")]), - (reoblique, outputnode, [("out_epi", "fmap_ref"), - ("out_mask", "fmap_mask")]), + (zooms_field, zooms_bmask, [("output_image", "reference_image")]), + (zooms_field, bs_filter, [("output_image", "in_data")]), + (zooms_bmask, bs_filter, [("output_image", "in_mask")]), + (unwarp_ref, outputnode, [("output_image", "fmap_ref")]), + (zooms_bmask, outputnode, [("output_image", "fmap_mask")]), (bs_filter, outputnode, [ ("out_extrapolated" if not debug else "out_field", "fmap"), ("out_coeff", "fmap_coeff")]), @@ -492,7 +483,6 @@ def init_syn_preprocessing_wf( name="merge_output", run_without_submitting=True, ) - mask_anat = pe.Node(ApplyMask(), name="mask_anat") clip_anat = pe.Node( IntensityClip(p_min=0.0, p_max=99.8, invert=t1w_inversion), name="clip_anat" @@ -514,6 +504,13 @@ def init_syn_preprocessing_wf( IntensityClip(p_min=0.0, p_max=100), name="clip_anat_final" ) + def _remove_first_mask(in_file): + if not isinstance(in_file, list): + in_file = [in_file] + + in_file.insert(0, "NULL") + return in_file + anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk") epi_dilmsk = pe.Node(BinaryDilation(), name="epi_dilmsk") @@ -536,7 +533,8 @@ def init_syn_preprocessing_wf( (ref_anat, epi2anat, [("output_image", "fixed_image")]), (anat_dilmsk, epi2anat, [("out_file", "fixed_image_masks")]), (deob_epi, epi2anat, [("out_file", "moving_image")]), - (epi_dilmsk, epi2anat, [("out_file", "moving_image_masks")]), + (epi_dilmsk, epi2anat, [ + (("out_file", _remove_first_mask), "moving_image_masks")]), (deob_epi, sampling_ref, [("out_file", "fixed_image")]), (epi2anat, transform_list, [("forward_transforms", "in1")]), (transform_list, prior2epi, [("out", "transforms")]), @@ -615,7 +613,7 @@ def _warp_dir(intuple, nlevels=3): return nlevels * [[int(pe == ax) for ax in "ijk"]] -def _extract_field(in_file, epi_meta): +def _extract_field(in_file, epi_meta, in_mask=None): """ Extract the nonzero component of the deformation field estimated by ANTs. @@ -648,6 +646,15 @@ def _extract_field(in_file, epi_meta): / trt * (-1.0 if epi_meta[1]["PhaseEncodingDirection"].endswith("-") else 1.0) ) + + if in_mask is None: + mask = np.ones_like(data, dtype=bool) + else: + mask = np.asanyarray(nb.load(in_mask).dataobj, dtype=bool) + + # De-mean the result + data -= np.median(data[mask]) + out_file = Path(fname_presuffix(Path(in_file[0]).name, suffix="_fieldmap")) nii = nb.Nifti1Image(data, fieldnii.affine, None) nii.header.set_xyzt_units(fieldnii.header.get_xyzt_units()[0]) @@ -688,3 +695,41 @@ def _adjust_zooms(in_anat, in_epi, z_max=2.2, z_min=1.8): z_max, ) return tuple([zoom_iso] * 3) + + +def match_histogram(reference, image, ref_mask=None, img_mask=None): + """Match the histogram of the T2-like anatomical with the EPI.""" + import os + import numpy as np + import nibabel as nb + from nipype.utils.filemanip import fname_presuffix + from skimage.exposure import match_histograms + + nii_img = nb.load(image) + img_data = np.asanyarray(nii_img.dataobj) + ref_data = np.asanyarray(nb.load(reference).dataobj) + + ref_mask = ( + np.ones_like(ref_data, dtype=bool) if ref_mask is None + else np.asanyarray(nb.load(ref_mask).dataobj) > 0 + ) + + img_mask = ( + np.ones_like(img_data, dtype=bool) if img_mask is None + else np.asanyarray(nb.load(img_mask).dataobj) > 0 + ) + + out_file = fname_presuffix( + image, suffix="_matched", newpath=os.getcwd() + ) + img_data[img_mask] = match_histograms( + img_data[img_mask], + ref_data[ref_mask], + ) + + nii_img.__class__( + img_data, + nii_img.affine, + nii_img.header, + ).to_filename(out_file) + return out_file