Skip to content

Commit 816c894

Browse files
committed
fix: robustify displacements/fieldmap conversions
1 parent dbec0d9 commit 816c894

File tree

4 files changed

+167
-89
lines changed

4 files changed

+167
-89
lines changed

sdcflows/interfaces/fmap.py

+36
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,39 @@ def _run_interface(self, runtime):
152152
self._results["out_file"]
153153
)
154154
return runtime
155+
156+
157+
class _DisplacementsField2FieldmapInputSpec(BaseInterfaceInputSpec):
158+
transform = File(exists=True, mandatory=True, desc="input displacements field")
159+
ro_time = traits.Float(mandatory=True, desc="total readout time")
160+
pe_dir = traits.Enum(
161+
"j-", "j", "i", "i-", "k", "k-", mandatory=True, desc="phase encoding direction"
162+
)
163+
itk_transform = traits.Bool(
164+
True, usedefault=True, desc="whether this is an ITK/ANTs transform"
165+
)
166+
167+
168+
class _DisplacementsField2FieldmapOutputSpec(TraitedSpec):
169+
out_file = File(exists=True, desc="output fieldmap in Hz")
170+
171+
172+
class DisplacementsField2Fieldmap(SimpleInterface):
173+
"""Convert from a transform to a B0 fieldmap in Hz."""
174+
175+
input_spec = _DisplacementsField2FieldmapInputSpec
176+
output_spec = _DisplacementsField2FieldmapOutputSpec
177+
178+
def _run_interface(self, runtime):
179+
from sdcflows.transform import disp_to_fmap
180+
181+
self._results["out_file"] = fname_presuffix(
182+
self.inputs.in_file, suffix="_Hz", newpath=runtime.cwd
183+
)
184+
disp_to_fmap(
185+
nb.load(self.inputs.transform),
186+
ro_time=self.inputs.ro_time,
187+
pe_dir=self.inputs.pe_dir,
188+
itk_format=self.inputs.itk_transform,
189+
).to_filename(self._results["out_file"])
190+
return runtime

sdcflows/tests/test_transform.py

+23
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,26 @@ def test_displacements_field(tmpdir, testdata_dir, outdir, pe_dir, rotation, fli
134134
f"_y-{rotation[1] or 0}_z-{rotation[2] or 0}.svg"
135135
),
136136
).run()
137+
138+
139+
@pytest.mark.parametrize("pe_dir", ["j", "j-", "i", "i-", "k", "k-"])
140+
def test_conversions(tmpdir, testdata_dir, pe_dir):
141+
"""Check idempotency."""
142+
tmpdir.chdir()
143+
144+
fmap_nii = nb.load(testdata_dir / "topup-field.nii.gz")
145+
new_nii = tf.disp_to_fmap(
146+
tf.fmap_to_disp(
147+
fmap_nii,
148+
ro_time=0.2,
149+
pe_dir=pe_dir,
150+
),
151+
ro_time=0.2,
152+
pe_dir=pe_dir,
153+
)
154+
155+
new_nii.to_filename("test.nii.gz")
156+
assert np.allclose(
157+
fmap_nii.get_fdata(dtype="float32"),
158+
new_nii.get_fdata(dtype="float32"),
159+
)

sdcflows/transform.py

+102-31
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def fit(self, spatialimage):
8989

9090
# Interpolate the VSM (voxel-shift map)
9191
vsm = np.zeros(spatialimage.shape[:3], dtype="float32")
92-
vsm = (np.squeeze(np.vstack(coeffs).T) @ sparse_vstack(weights)).reshape(
92+
vsm = (np.squeeze(np.hstack(coeffs).T) @ sparse_vstack(weights)).reshape(
9393
vsm.shape
9494
)
9595

@@ -215,36 +215,107 @@ def to_displacements(self, ro_time, pe_dir, itk_format=True):
215215
A NIfTI 1.0 object containing the distortion.
216216
217217
"""
218-
# Set polarity & scale VSM (voxel-shift-map) by readout time
219-
vsm = self.shifts.get_fdata().copy()
220-
pe_axis = "ijk".index(pe_dir[0])
221-
vsm *= -1.0 if pe_dir.endswith("-") else 1.0
222-
vsm *= ro_time
223-
224-
# Shape of displacements field
225-
# Note that ITK NIfTI fields are 5D (have an empty 4th dimension)
226-
fieldshape = tuple(list(vsm.shape[:3]) + [1, 3])
227-
228-
# Convert VSM to voxel displacements
229-
ijk_deltas = np.zeros((vsm.size, 3), dtype="float32")
230-
ijk_deltas[:, pe_axis] = vsm.reshape(-1)
231-
232-
# To convert from VSM to RAS field we just apply the affine
233-
aff = self.shifts.affine.copy()
234-
aff[:3, 3] = 0 # Translations MUST NOT be applied, though.
235-
xyz_deltas = nb.affines.apply_affine(aff, ijk_deltas)
236-
if itk_format:
237-
# ITK displacement vectors are in LPS orientation
238-
xyz_deltas[..., (0, 1)] *= -1.0
239-
240-
xyz_nii = nb.Nifti1Image(
241-
xyz_deltas.reshape(fieldshape),
242-
self.shifts.affine,
243-
None,
244-
)
245-
xyz_nii.header.set_intent("vector", (), "")
246-
xyz_nii.header.set_xyzt_units("mm")
247-
return xyz_nii
218+
return fmap_to_disp(self.shifts, ro_time, pe_dir, itk_format=itk_format)
219+
220+
221+
def fmap_to_disp(fmap_nii, ro_time, pe_dir, itk_format=True):
222+
"""
223+
Convert a fieldmap in Hz into an ITK/ANTs-compatible displacements field.
224+
225+
The displacements field can be calculated following
226+
`Eq. (2) in the fieldmap fitting section
227+
<sdcflows.workflows.fit.fieldmap.html#mjx-eqn-eq%3Afieldmap-2>`__.
228+
229+
Parameters
230+
----------
231+
fmap_nii : :obj:`os.pathlike`
232+
Path to a voxel-shift-map (VSM) in NIfTI format
233+
ro_time : :obj:`float`
234+
The total readout time in seconds (only if ``vsm=False``).
235+
pe_dir : :obj:`str`
236+
The ``PhaseEncodingDirection`` metadata value (only if ``vsm=False``).
237+
238+
Returns
239+
-------
240+
spatialimage : :obj:`nibabel.nifti.Nifti1Image`
241+
A NIfTI 1.0 object containing the distortion.
242+
243+
"""
244+
# Set polarity & scale VSM (voxel-shift-map) by readout time
245+
vsm = fmap_nii.get_fdata().copy() * (-ro_time if pe_dir.endswith("-") else ro_time)
246+
247+
# Shape of displacements field
248+
# Note that ITK NIfTI fields are 5D (have an empty 4th dimension)
249+
fieldshape = tuple(list(vsm.shape[:3]) + [1, 3])
250+
251+
# Convert VSM to voxel displacements
252+
ijk_deltas = np.zeros((vsm.size, 3), dtype="float32")
253+
ijk_deltas[:, "ijk".index(pe_dir[0])] = vsm.reshape(-1)
254+
255+
# To convert from VSM to RAS field we just apply the affine
256+
aff = fmap_nii.affine.copy()
257+
aff[:3, 3] = 0 # Translations MUST NOT be applied, though.
258+
xyz_deltas = nb.affines.apply_affine(aff, ijk_deltas)
259+
if itk_format:
260+
# ITK displacement vectors are in LPS orientation
261+
xyz_deltas[..., (0, 1)] *= -1.0
262+
263+
xyz_nii = nb.Nifti1Image(
264+
xyz_deltas.reshape(fieldshape),
265+
fmap_nii.affine,
266+
None,
267+
)
268+
xyz_nii.header.set_intent("vector", (), "")
269+
xyz_nii.header.set_xyzt_units("mm")
270+
return xyz_nii
271+
272+
273+
def disp_to_fmap(xyz_nii, ro_time, pe_dir, itk_format=True):
274+
"""
275+
Convert a displacements field into a fieldmap in Hz.
276+
277+
This is the dual operation to the previous function.
278+
279+
Parameters
280+
----------
281+
xyz_nii : :obj:`os.pathlike`
282+
Path to a displacements field in NIfTI format.
283+
ro_time : :obj:`float`
284+
The total readout time in seconds (only if ``vsm=False``).
285+
pe_dir : :obj:`str`
286+
The ``PhaseEncodingDirection`` metadata value (only if ``vsm=False``).
287+
288+
Returns
289+
-------
290+
spatialimage : :obj:`nibabel.nifti.Nifti1Image`
291+
A NIfTI 1.0 object containing the field in Hz.
292+
293+
"""
294+
xyz_deltas = np.squeeze(xyz_nii.get_fdata(dtype="float32")).reshape((-1, 3))
295+
296+
if itk_format:
297+
# ITK displacement vectors are in LPS orientation
298+
xyz_deltas[:, (0, 1)] *= -1
299+
300+
inv_aff = np.linalg.inv(xyz_nii.affine)
301+
inv_aff[:3, 3] = 0 # Translations MUST NOT be applied.
302+
303+
# Convert displacements from mm to voxel units
304+
# Using the inverse affine accounts for reordering of axes, etc.
305+
ijk_deltas = nb.affines.apply_affine(inv_aff, xyz_deltas).astype("float32")
306+
ijk_deltas = (
307+
ijk_deltas[:, "ijk".index(pe_dir[0])]
308+
* (-1.0 if pe_dir.endswith("-") else 1.0)
309+
/ ro_time
310+
)
311+
312+
ijk_nii = nb.Nifti1Image(
313+
ijk_deltas.reshape(xyz_nii.shape[:3]),
314+
xyz_nii.affine,
315+
None,
316+
)
317+
ijk_nii.header.set_xyzt_units("mm")
318+
return ijk_nii
248319

249320

250321
def _cubic_bspline(d):

sdcflows/workflows/fit/syn.py

+6-58
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def init_syn_sdc_wf(
173173
)
174174
from ...utils.misc import front as _pop, last as _pull
175175
from ...interfaces.epi import GetReadoutTime
176+
from ...interfaces.fmap import DisplacementsField2Fieldmap
176177
from ...interfaces.bspline import (
177178
ApplyCoeffsField,
178179
BSplineApprox,
@@ -288,7 +289,7 @@ def init_syn_sdc_wf(
288289
unwarp = pe.Node(ApplyCoeffsField(), name="unwarp")
289290

290291
# Extract nonzero component
291-
extract_field = pe.Node(niu.Function(function=_extract_field), name="extract_field")
292+
extract_field = pe.Node(DisplacementsField2Fieldmap(), name="extract_field")
292293

293294
# Check zooms (avoid very expensive B-Splines fitting)
294295
zooms_field = pe.Node(
@@ -316,7 +317,6 @@ def init_syn_sdc_wf(
316317
workflow.connect([
317318
(inputnode, readout_time, [(("epi_ref", _pop), "in_file"),
318319
(("epi_ref", _pull), "metadata")]),
319-
(inputnode, extract_field, [("epi_ref", "epi_meta")]),
320320
(inputnode, atlas_msk, [("sd_prior", "in_file")]),
321321
(inputnode, clip_epi, [(("epi_ref", _pop), "in_file")]),
322322
(inputnode, unwarp, [(("epi_ref", _pop), "in_data")]),
@@ -351,8 +351,10 @@ def init_syn_sdc_wf(
351351
(fixed_masks, syn, [("out", "fixed_image_masks")]),
352352
(epi_merge, syn, [("out", "moving_image")]),
353353
(moving_masks, syn, [("out", "moving_image_masks")]),
354-
(syn, extract_field, [("forward_transforms", "in_file")]),
355-
(extract_field, zooms_field, [("out", "input_image")]),
354+
(syn, extract_field, [(("forward_transforms", _pop), "transform")]),
355+
(readout_time, extract_field, [("readout_time", "ro_time"),
356+
("pe_direction", "pe_dir")]),
357+
(extract_field, zooms_field, [("out_file", "input_image")]),
356358
(zooms_field, zooms_bmask, [("output_image", "reference_image")]),
357359
(zooms_field, bs_filter, [("output_image", "in_data")]),
358360
(zooms_bmask, bs_filter, [("output_image", "in_mask")]),
@@ -631,60 +633,6 @@ def _warp_dir(intuple, nlevels=3):
631633
return nlevels * [[1 if pe == ax else 0.1 for ax in "ijk"]]
632634

633635

634-
def _extract_field(in_file, epi_meta, in_mask=None, demean=True):
635-
"""
636-
Extract the nonzero component of the deformation field estimated by ANTs.
637-
638-
Examples
639-
--------
640-
>>> nii = nb.load(
641-
... _extract_field(
642-
... ["field.nii.gz"],
643-
... ("epi.nii.gz", {"PhaseEncodingDirection": "j-", "TotalReadoutTime": 0.005}),
644-
... demean=False,
645-
... )
646-
... )
647-
>>> nii.shape
648-
(10, 10, 10)
649-
650-
>>> np.allclose(nii.get_fdata(), -200)
651-
True
652-
653-
"""
654-
from pathlib import Path
655-
from nipype.utils.filemanip import fname_presuffix
656-
import numpy as np
657-
import nibabel as nb
658-
from sdcflows.utils.epimanip import get_trt
659-
660-
fieldnii = nb.load(in_file[0])
661-
trt = get_trt(epi_meta[1], in_file=epi_meta[0])
662-
data = (
663-
np.squeeze(fieldnii.get_fdata(dtype="float32"))[
664-
..., "ijk".index(epi_meta[1]["PhaseEncodingDirection"][0])
665-
]
666-
/ trt
667-
* (-1.0 if epi_meta[1]["PhaseEncodingDirection"].endswith("-") else 1.0)
668-
)
669-
670-
if ["PhaseEncodingDirection"][0] in "ij":
671-
data *= -1.0 # ITK/ANTs is an LPS system, flip direction
672-
673-
if demean:
674-
mask = (
675-
np.ones_like(data, dtype=bool) if in_mask is None
676-
else np.asanyarray(nb.load(in_mask).dataobj, dtype=bool)
677-
)
678-
# De-mean the result
679-
data -= np.median(data[mask])
680-
681-
out_file = Path(fname_presuffix(Path(in_file[0]).name, suffix="_fieldmap"))
682-
nii = nb.Nifti1Image(data, fieldnii.affine, None)
683-
nii.header.set_xyzt_units(fieldnii.header.get_xyzt_units()[0])
684-
nii.to_filename(out_file)
685-
return str(out_file.absolute())
686-
687-
688636
def _merge_meta(epi_ref, meta_list):
689637
"""Prepare a tuple of EPI reference and metadata."""
690638
return (epi_ref, meta_list[0])

0 commit comments

Comments
 (0)