Skip to content

Commit ec86b74

Browse files
authored
Merge pull request #2005 from mmodat/2004-niftyreg-base-intf
Fixing the base interface _run_interface function for NiftyRegCommand
2 parents cce2f97 + 4757bec commit ec86b74

12 files changed

+61
-24
lines changed

nipype/interfaces/niftyreg/base.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import subprocess
2727
from warnings import warn
2828

29-
from ..base import CommandLine, isdefined, CommandLineInputSpec, traits
29+
from ..base import CommandLine, CommandLineInputSpec, traits, Undefined
3030
from ...utils.filemanip import split_filename
3131

3232

@@ -47,8 +47,9 @@ def no_nifty_package(cmd='reg_f3d'):
4747
class NiftyRegCommandInputSpec(CommandLineInputSpec):
4848
"""Input Spec for niftyreg interfaces."""
4949
# Set the number of omp thread to use
50-
omp_core_val = traits.Int(desc='Number of openmp thread to use',
51-
argstr='-omp %i')
50+
omp_core_val = traits.Int(int(os.environ.get('OMP_NUM_THREADS', '1')),
51+
desc='Number of openmp thread to use',
52+
argstr='-omp %i', usedefault=True)
5253

5354

5455
class NiftyRegCommand(CommandLine):
@@ -58,7 +59,10 @@ class NiftyRegCommand(CommandLine):
5859
_suffix = '_nr'
5960
_min_version = '1.5.30'
6061

62+
input_spec = NiftyRegCommandInputSpec
63+
6164
def __init__(self, required_version=None, **inputs):
65+
self.num_threads = 1
6266
super(NiftyRegCommand, self).__init__(**inputs)
6367
self.required_version = required_version
6468
_version = self.get_version()
@@ -73,6 +77,29 @@ def __init__(self, required_version=None, **inputs):
7377
msg = 'The version of NiftyReg differs from the required'
7478
msg += '(%s != %s)'
7579
warn(msg % (_version, self.required_version))
80+
self.inputs.on_trait_change(self._omp_update, 'omp_core_val')
81+
self.inputs.on_trait_change(self._environ_update, 'environ')
82+
self._omp_update()
83+
84+
def _omp_update(self):
85+
if self.inputs.omp_core_val:
86+
self.inputs.environ['OMP_NUM_THREADS'] = \
87+
str(self.inputs.omp_core_val)
88+
self.num_threads = self.inputs.omp_core_val
89+
else:
90+
if 'OMP_NUM_THREADS' in self.inputs.environ:
91+
del self.inputs.environ['OMP_NUM_THREADS']
92+
self.num_threads = 1
93+
94+
def _environ_update(self):
95+
if self.inputs.environ:
96+
if 'OMP_NUM_THREADS' in self.inputs.environ:
97+
self.inputs.omp_core_val = \
98+
int(self.inputs.environ['OMP_NUM_THREADS'])
99+
else:
100+
self.inputs.omp_core_val = Undefined
101+
else:
102+
self.inputs.omp_core_val = Undefined
76103

77104
def check_version(self):
78105
_version = self.get_version()
@@ -102,13 +129,6 @@ def version(self):
102129
def exists(self):
103130
return self.get_version() is not None
104131

105-
def _run_interface(self, runtime):
106-
# Update num threads estimate from OMP_NUM_THREADS env var
107-
# Default to 1 if not set
108-
if not isdefined(self.inputs.environ['OMP_NUM_THREADS']):
109-
self.inputs.environ['OMP_NUM_THREADS'] = self.num_threads
110-
return super(NiftyRegCommand, self)._run_interface(runtime)
111-
112132
def _format_arg(self, name, spec, value):
113133
if name == 'omp_core_val':
114134
self.numthreads = value

nipype/interfaces/niftyreg/tests/test_auto_NiftyRegCommand.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def test_NiftyRegCommand_inputs():
1212
ignore_exception=dict(nohash=True,
1313
usedefault=True,
1414
),
15+
omp_core_val=dict(argstr='-omp %i',
16+
usedefault=True,
17+
),
1518
terminal_output=dict(nohash=True,
1619
),
1720
)

nipype/interfaces/niftyreg/tests/test_auto_RegAladin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_RegAladin_inputs():
4646
nosym_flag=dict(argstr='-noSym',
4747
),
4848
omp_core_val=dict(argstr='-omp %i',
49+
usedefault=True,
4950
),
5051
platform_val=dict(argstr='-platf %i',
5152
),

nipype/interfaces/niftyreg/tests/test_auto_RegAverage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_RegAverage_inputs():
4343
usedefault=True,
4444
),
4545
omp_core_val=dict(argstr='-omp %i',
46+
usedefault=True,
4647
),
4748
out_file=dict(argstr='%s',
4849
genfile=True,

nipype/interfaces/niftyreg/tests/test_auto_RegF3D.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def test_RegF3D_inputs():
7676
noz_flag=dict(argstr='-noz',
7777
),
7878
omp_core_val=dict(argstr='-omp %i',
79+
usedefault=True,
7980
),
8081
pad_val=dict(argstr='-pad %f',
8182
),

nipype/interfaces/niftyreg/tests/test_auto_RegJacobian.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def test_RegJacobian_inputs():
1313
usedefault=True,
1414
),
1515
omp_core_val=dict(argstr='-omp %i',
16+
usedefault=True,
1617
),
1718
out_file=dict(argstr='%s',
1819
name_source=['trans_file'],

nipype/interfaces/niftyreg/tests/test_auto_RegMeasure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def test_RegMeasure_inputs():
1919
mandatory=True,
2020
),
2121
omp_core_val=dict(argstr='-omp %i',
22+
usedefault=True,
2223
),
2324
out_file=dict(argstr='-out %s',
2425
name_source=['flo_file'],

nipype/interfaces/niftyreg/tests/test_auto_RegResample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_RegResample_inputs():
1818
inter_val=dict(argstr='-inter %d',
1919
),
2020
omp_core_val=dict(argstr='-omp %i',
21+
usedefault=True,
2122
),
2223
out_file=dict(argstr='%s',
2324
name_source=['flo_file'],

nipype/interfaces/niftyreg/tests/test_auto_RegTools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_RegTools_inputs():
3434
noscl_flag=dict(argstr='-noscl',
3535
),
3636
omp_core_val=dict(argstr='-omp %i',
37+
usedefault=True,
3738
),
3839
out_file=dict(argstr='-out %s',
3940
name_source=['in_file'],

nipype/interfaces/niftyreg/tests/test_auto_RegTransform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_RegTransform_inputs():
5757
xor=['def_input', 'disp_input', 'flow_input', 'comp_input', 'upd_s_form_input', 'inv_aff_input', 'inv_nrr_input', 'half_input', 'aff_2_rig_input', 'flirt_2_nr_input'],
5858
),
5959
omp_core_val=dict(argstr='-omp %i',
60+
usedefault=True,
6061
),
6162
out_file=dict(argstr='%s',
6263
genfile=True,

nipype/interfaces/niftyreg/tests/test_regutils.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def test_reg_average():
190190
two_file = example_data('im2.nii')
191191
three_file = example_data('im3.nii')
192192
nr_average.inputs.avg_files = [one_file, two_file, three_file]
193+
nr_average.inputs.omp_core_val = 1
193194
generated_cmd = nr_average.cmdline
194195

195196
# Read the reg_average_cmd
@@ -198,10 +199,10 @@ def test_reg_average():
198199
argv = f_obj.read()
199200
os.remove(reg_average_cmd)
200201

201-
expected_argv = '%s %s -avg %s %s %s' % (get_custom_path('reg_average'),
202-
os.path.join(os.getcwd(),
203-
'avg_out.nii.gz'),
204-
one_file, two_file, three_file)
202+
expected_argv = '%s %s -avg %s %s %s -omp 1' % (
203+
get_custom_path('reg_average'),
204+
os.path.join(os.getcwd(), 'avg_out.nii.gz'),
205+
one_file, two_file, three_file)
205206

206207
assert argv.decode('utf-8') == expected_argv
207208

@@ -217,6 +218,7 @@ def test_reg_average():
217218
two_file = example_data('ants_Affine.txt')
218219
three_file = example_data('elastix.txt')
219220
nr_average_2.inputs.avg_files = [one_file, two_file, three_file]
221+
nr_average_2.inputs.omp_core_val = 1
220222
generated_cmd = nr_average_2.cmdline
221223

222224
# Read the reg_average_cmd
@@ -225,10 +227,10 @@ def test_reg_average():
225227
argv = f_obj.read()
226228
os.remove(reg_average_cmd)
227229

228-
expected_argv = '%s %s -avg %s %s %s' % (get_custom_path('reg_average'),
229-
os.path.join(os.getcwd(),
230-
'avg_out.txt'),
231-
one_file, two_file, three_file)
230+
expected_argv = '%s %s -avg %s %s %s -omp 1' % (
231+
get_custom_path('reg_average'),
232+
os.path.join(os.getcwd(), 'avg_out.txt'),
233+
one_file, two_file, three_file)
232234

233235
assert argv.decode('utf-8') == expected_argv
234236

@@ -238,6 +240,7 @@ def test_reg_average():
238240
two_file = example_data('ants_Affine.txt')
239241
three_file = example_data('elastix.txt')
240242
nr_average_3.inputs.avg_lts_files = [one_file, two_file, three_file]
243+
nr_average_3.inputs.omp_core_val = 1
241244
generated_cmd = nr_average_3.cmdline
242245

243246
# Read the reg_average_cmd
@@ -246,7 +249,7 @@ def test_reg_average():
246249
argv = f_obj.read()
247250
os.remove(reg_average_cmd)
248251

249-
expected_argv = ('%s %s -avg_lts %s %s %s'
252+
expected_argv = ('%s %s -avg_lts %s %s %s -omp 1'
250253
% (get_custom_path('reg_average'),
251254
os.path.join(os.getcwd(), 'avg_out.txt'),
252255
one_file, two_file, three_file))
@@ -266,6 +269,7 @@ def test_reg_average():
266269
trans2_file, two_file,
267270
trans3_file, three_file]
268271
nr_average_4.inputs.avg_ref_file = ref_file
272+
nr_average_4.inputs.omp_core_val = 1
269273
generated_cmd = nr_average_4.cmdline
270274

271275
# Read the reg_average_cmd
@@ -274,12 +278,12 @@ def test_reg_average():
274278
argv = f_obj.read()
275279
os.remove(reg_average_cmd)
276280

277-
expected_argv = ('%s %s -avg_tran %s %s %s %s %s %s %s'
281+
expected_argv = ('%s %s -avg_tran %s -omp 1 %s %s %s %s %s %s'
278282
% (get_custom_path('reg_average'),
279283
os.path.join(os.getcwd(), 'avg_out.nii.gz'),
280284
ref_file, trans1_file, one_file, trans2_file, two_file,
281285
trans3_file, three_file))
282-
286+
283287
assert argv.decode('utf-8') == expected_argv
284288

285289
# Test Reg Average: demean3
@@ -298,6 +302,7 @@ def test_reg_average():
298302
aff2_file, trans2_file, two_file,
299303
aff3_file, trans3_file, three_file]
300304
nr_average_5.inputs.demean3_ref_file = ref_file
305+
nr_average_5.inputs.omp_core_val = 1
301306
generated_cmd = nr_average_5.cmdline
302307

303308
# Read the reg_average_cmd
@@ -306,7 +311,7 @@ def test_reg_average():
306311
argv = f_obj.read()
307312
os.remove(reg_average_cmd)
308313

309-
expected_argv = ('%s %s -demean3 %s %s %s %s %s %s %s %s %s %s'
314+
expected_argv = ('%s %s -demean3 %s -omp 1 %s %s %s %s %s %s %s %s %s'
310315
% (get_custom_path('reg_average'),
311316
os.path.join(os.getcwd(), 'avg_out.nii.gz'),
312317
ref_file,

nipype/interfaces/niftyseg/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
See the docstrings of the individual classes for examples.
1717
"""
1818

19-
from nipype.interfaces.niftyreg.base import NiftyRegCommand, no_nifty_package
19+
from nipype.interfaces.niftyreg.base import no_nifty_package
20+
from nipype.interfaces.niftyfit.base import NiftyFitCommand
2021
import subprocess
2122
import warnings
2223

@@ -25,7 +26,7 @@
2526
warnings.filterwarnings('always', category=UserWarning)
2627

2728

28-
class NiftySegCommand(NiftyRegCommand):
29+
class NiftySegCommand(NiftyFitCommand):
2930
"""
3031
Base support interface for NiftySeg commands.
3132
"""

0 commit comments

Comments
 (0)