Skip to content

Commit 935d209

Browse files
authored
Merge pull request #3280 from tclose/predicted_signal
ENH: Adds interfaces for MRtrix utils shconv and sh2amp
2 parents 5d2e224 + e096c19 commit 935d209

11 files changed

+288
-3
lines changed

Diff for: .zenodo.json

+5
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,11 @@
711711
{
712712
"name": "Marina, Ana"
713713
},
714+
{
715+
"affiliation": "University of Sydney",
716+
"name": "Close, Thomas",
717+
"orcid": "0000-0002-4160-2134"
718+
},
714719
{
715720
"name": "Davison, Andrew"
716721
},

Diff for: nipype/interfaces/mrtrix3/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
MRConvert,
1414
MRResize,
1515
DWIExtract,
16+
SHConv,
17+
SH2Amp,
1618
)
1719
from .preprocess import (
1820
ResponseSD,

Diff for: nipype/interfaces/mrtrix3/reconst.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os.path as op
66

7-
from ..base import traits, TraitedSpec, File, Undefined, InputMultiObject
7+
from ..base import traits, TraitedSpec, File, InputMultiObject, isdefined
88
from .base import MRTrix3BaseInputSpec, MRTrix3Base
99

1010

@@ -50,10 +50,18 @@ class FitTensorInputSpec(MRTrix3BaseInputSpec):
5050
"only applies to the non-linear methods"
5151
),
5252
)
53+
predicted_signal = File(
54+
argstr="-predicted_signal %s",
55+
desc=(
56+
"specify a file to contain the predicted signal from the tensor "
57+
"fits. This can be used to calculate the residual signal"
58+
),
59+
)
5360

5461

5562
class FitTensorOutputSpec(TraitedSpec):
5663
out_file = File(exists=True, desc="the output DTI file")
64+
predicted_signal = File(desc="Predicted signal from fitted tensors")
5765

5866

5967
class FitTensor(MRTrix3Base):
@@ -81,6 +89,8 @@ class FitTensor(MRTrix3Base):
8189
def _list_outputs(self):
8290
outputs = self.output_spec().get()
8391
outputs["out_file"] = op.abspath(self.inputs.out_file)
92+
if isdefined(self.inputs.predicted_signal):
93+
outputs["predicted_signal"] = op.abspath(self.inputs.predicted_signal)
8494
return outputs
8595

8696

@@ -144,12 +154,23 @@ class EstimateFODInputSpec(MRTrix3BaseInputSpec):
144154
"[ az el ] pairs for the directions."
145155
),
146156
)
157+
predicted_signal = File(
158+
argstr="-predicted_signal %s",
159+
desc=(
160+
"specify a file to contain the predicted signal from the FOD "
161+
"estimates. This can be used to calculate the residual signal."
162+
"Note that this is only valid if algorithm == 'msmt_csd'. "
163+
"For single shell reconstructions use a combination of SHConv "
164+
"and SH2Amp instead."
165+
),
166+
)
147167

148168

149169
class EstimateFODOutputSpec(TraitedSpec):
150170
wm_odf = File(argstr="%s", desc="output WM ODF")
151171
gm_odf = File(argstr="%s", desc="output GM ODF")
152172
csf_odf = File(argstr="%s", desc="output CSF ODF")
173+
predicted_signal = File(desc="output predicted signal")
153174

154175

155176
class EstimateFOD(MRTrix3Base):
@@ -183,10 +204,17 @@ class EstimateFOD(MRTrix3Base):
183204
def _list_outputs(self):
184205
outputs = self.output_spec().get()
185206
outputs["wm_odf"] = op.abspath(self.inputs.wm_odf)
186-
if self.inputs.gm_odf != Undefined:
207+
if isdefined(self.inputs.gm_odf):
187208
outputs["gm_odf"] = op.abspath(self.inputs.gm_odf)
188-
if self.inputs.csf_odf != Undefined:
209+
if isdefined(self.inputs.csf_odf):
189210
outputs["csf_odf"] = op.abspath(self.inputs.csf_odf)
211+
if isdefined(self.inputs.predicted_signal):
212+
if self.inputs.algorithm != "msmt_csd":
213+
raise Exception(
214+
"'predicted_signal' option can only be used with "
215+
"the 'msmt_csd' algorithm"
216+
)
217+
outputs["predicted_signal"] = op.abspath(self.inputs.predicted_signal)
190218
return outputs
191219

192220

Diff for: nipype/interfaces/mrtrix3/tests/test_auto_ConstrainedSphericalDeconvolution.py

+7
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def test_ConstrainedSphericalDeconvolution_inputs():
7777
argstr="-nthreads %d",
7878
nohash=True,
7979
),
80+
predicted_signal=dict(
81+
argstr="-predicted_signal %s",
82+
extensions=None,
83+
),
8084
shell=dict(
8185
argstr="-shell %s",
8286
sep=",",
@@ -112,6 +116,9 @@ def test_ConstrainedSphericalDeconvolution_outputs():
112116
argstr="%s",
113117
extensions=None,
114118
),
119+
predicted_signal=dict(
120+
extensions=None,
121+
),
115122
wm_odf=dict(
116123
argstr="%s",
117124
extensions=None,

Diff for: nipype/interfaces/mrtrix3/tests/test_auto_EstimateFOD.py

+7
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def test_EstimateFOD_inputs():
8080
argstr="-nthreads %d",
8181
nohash=True,
8282
),
83+
predicted_signal=dict(
84+
argstr="-predicted_signal %s",
85+
extensions=None,
86+
),
8387
shell=dict(
8488
argstr="-shell %s",
8589
sep=",",
@@ -115,6 +119,9 @@ def test_EstimateFOD_outputs():
115119
argstr="%s",
116120
extensions=None,
117121
),
122+
predicted_signal=dict(
123+
extensions=None,
124+
),
118125
wm_odf=dict(
119126
argstr="%s",
120127
extensions=None,

Diff for: nipype/interfaces/mrtrix3/tests/test_auto_FitTensor.py

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def test_FitTensor_inputs():
5454
position=-1,
5555
usedefault=True,
5656
),
57+
predicted_signal=dict(
58+
argstr="-predicted_signal %s",
59+
extensions=None,
60+
),
5761
reg_term=dict(
5862
argstr="-regularisation %f",
5963
max_ver="0.3.13",
@@ -71,6 +75,9 @@ def test_FitTensor_outputs():
7175
out_file=dict(
7276
extensions=None,
7377
),
78+
predicted_signal=dict(
79+
extensions=None,
80+
),
7481
)
7582
outputs = FitTensor.output_spec()
7683

Diff for: nipype/interfaces/mrtrix3/tests/test_auto_SH2Amp.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
2+
from ..utils import SH2Amp
3+
4+
5+
def test_SH2Amp_inputs():
6+
input_map = dict(
7+
args=dict(
8+
argstr="%s",
9+
),
10+
directions=dict(
11+
argstr="%s",
12+
extensions=None,
13+
mandatory=True,
14+
position=-2,
15+
),
16+
environ=dict(
17+
nohash=True,
18+
usedefault=True,
19+
),
20+
in_file=dict(
21+
argstr="%s",
22+
extensions=None,
23+
mandatory=True,
24+
position=-3,
25+
),
26+
nonnegative=dict(
27+
argstr="-nonnegative",
28+
),
29+
out_file=dict(
30+
argstr="%s",
31+
extensions=None,
32+
name_source=["in_file"],
33+
name_template="%s_amp.mif",
34+
position=-1,
35+
usedefault=True,
36+
),
37+
)
38+
inputs = SH2Amp.input_spec()
39+
40+
for key, metadata in list(input_map.items()):
41+
for metakey, value in list(metadata.items()):
42+
assert getattr(inputs.traits()[key], metakey) == value
43+
44+
45+
def test_SH2Amp_outputs():
46+
output_map = dict(
47+
out_file=dict(
48+
extensions=None,
49+
),
50+
)
51+
outputs = SH2Amp.output_spec()
52+
53+
for key, metadata in list(output_map.items()):
54+
for metakey, value in list(metadata.items()):
55+
assert getattr(outputs.traits()[key], metakey) == value

Diff for: nipype/interfaces/mrtrix3/tests/test_auto_SHConv.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
2+
from ..utils import SHConv
3+
4+
5+
def test_SHConv_inputs():
6+
input_map = dict(
7+
args=dict(
8+
argstr="%s",
9+
),
10+
environ=dict(
11+
nohash=True,
12+
usedefault=True,
13+
),
14+
in_file=dict(
15+
argstr="%s",
16+
extensions=None,
17+
mandatory=True,
18+
position=-3,
19+
),
20+
out_file=dict(
21+
argstr="%s",
22+
extensions=None,
23+
name_source=["in_file"],
24+
name_template="%s_shconv.mif",
25+
position=-1,
26+
usedefault=True,
27+
),
28+
response=dict(
29+
argstr="%s",
30+
extensions=None,
31+
mandatory=True,
32+
position=-2,
33+
),
34+
)
35+
inputs = SHConv.input_spec()
36+
37+
for key, metadata in list(input_map.items()):
38+
for metakey, value in list(metadata.items()):
39+
assert getattr(inputs.traits()[key], metakey) == value
40+
41+
42+
def test_SHConv_outputs():
43+
output_map = dict(
44+
out_file=dict(
45+
extensions=None,
46+
),
47+
)
48+
outputs = SHConv.output_spec()
49+
50+
for key, metadata in list(output_map.items()):
51+
for metakey, value in list(metadata.items()):
52+
assert getattr(outputs.traits()[key], metakey) == value

Diff for: nipype/interfaces/mrtrix3/utils.py

+122
Original file line numberDiff line numberDiff line change
@@ -765,3 +765,125 @@ class MRResize(MRTrix3Base):
765765
_cmd = "mrresize"
766766
input_spec = MRResizeInputSpec
767767
output_spec = MRResizeOutputSpec
768+
769+
770+
class SHConvInputSpec(CommandLineInputSpec):
771+
in_file = File(
772+
exists=True,
773+
argstr="%s",
774+
mandatory=True,
775+
position=-3,
776+
desc="input ODF image",
777+
)
778+
# General options
779+
response = File(
780+
exists=True,
781+
mandatory=True,
782+
argstr="%s",
783+
position=-2,
784+
desc=("The response function"),
785+
)
786+
out_file = File(
787+
name_template="%s_shconv.mif",
788+
name_source=["in_file"],
789+
argstr="%s",
790+
position=-1,
791+
usedefault=True,
792+
desc="the output spherical harmonics",
793+
)
794+
795+
796+
class SHConvOutputSpec(TraitedSpec):
797+
out_file = File(exists=True, desc="the output convoluted spherical harmonics file")
798+
799+
800+
class SHConv(CommandLine):
801+
"""
802+
Convolve spherical harmonics with a tissue response function. Useful for
803+
checking residuals of ODF estimates.
804+
805+
806+
Example
807+
-------
808+
809+
>>> import nipype.interfaces.mrtrix3 as mrt
810+
>>> sh = mrt.SHConv()
811+
>>> sh.inputs.in_file = 'csd.mif'
812+
>>> sh.inputs.response = 'response.txt'
813+
>>> sh.cmdline
814+
'shconv csd.mif response.txt csd_shconv.mif'
815+
>>> sh.run() # doctest: +SKIP
816+
"""
817+
818+
_cmd = "shconv"
819+
input_spec = SHConvInputSpec
820+
output_spec = SHConvOutputSpec
821+
822+
def _list_outputs(self):
823+
outputs = self.output_spec().get()
824+
outputs["out_file"] = op.abspath(self.inputs.out_file)
825+
return outputs
826+
827+
828+
class SH2AmpInputSpec(CommandLineInputSpec):
829+
in_file = File(
830+
exists=True,
831+
argstr="%s",
832+
mandatory=True,
833+
position=-3,
834+
desc="input ODF image",
835+
)
836+
# General options
837+
directions = File(
838+
exists=True,
839+
mandatory=True,
840+
argstr="%s",
841+
position=-2,
842+
desc=(
843+
"The gradient directions along which to sample the spherical "
844+
"harmonics MRtrix format"
845+
),
846+
)
847+
out_file = File(
848+
name_template="%s_amp.mif",
849+
name_source=["in_file"],
850+
argstr="%s",
851+
position=-1,
852+
usedefault=True,
853+
desc="the output spherical harmonics",
854+
)
855+
nonnegative = traits.Bool(
856+
argstr="-nonnegative", desc="cap all negative amplitudes to zero"
857+
)
858+
859+
860+
class SH2AmpOutputSpec(TraitedSpec):
861+
out_file = File(exists=True, desc="the output convoluted spherical harmonics file")
862+
863+
864+
class SH2Amp(CommandLine):
865+
"""
866+
Sample spherical harmonics on a set of gradient orientations. Useful for
867+
checking residuals of ODF estimates.
868+
869+
870+
Example
871+
-------
872+
873+
>>> import nipype.interfaces.mrtrix3 as mrt
874+
>>> sh = mrt.SH2Amp()
875+
>>> sh.inputs.in_file = 'sh.mif'
876+
>>> sh.inputs.directions = 'grads.txt'
877+
>>> sh.cmdline
878+
'sh2amp sh.mif grads.txt sh_amp.mif'
879+
>>> sh.run() # doctest: +SKIP
880+
"""
881+
882+
_cmd = "sh2amp"
883+
input_spec = SH2AmpInputSpec
884+
output_spec = SH2AmpOutputSpec
885+
886+
def _list_outputs(self):
887+
outputs = self.output_spec().get()
888+
outputs["out_file"] = op.abspath(self.inputs.out_file)
889+
return outputs

Diff for: nipype/testing/data/grads.txt

Whitespace-only changes.

Diff for: nipype/testing/data/sh.mif

Whitespace-only changes.

0 commit comments

Comments
 (0)