Skip to content

Commit 91fbcce

Browse files
authored
Merge pull request #2673 from effigies/fix/double_squash
FIX: Prevent double-collapsing of nested lists by OutputMultiObject
2 parents 2e69844 + c50839e commit 91fbcce

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

nipype/pipeline/engine/tests/test_nodes.py

+15
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,18 @@ def test_inputs_removal(tmpdir):
290290
n1.overwrite = True
291291
n1.run()
292292
assert not tmpdir.join(n1.name, 'file1.txt').check()
293+
294+
295+
def test_outputmultipath_collapse(tmpdir):
296+
"""Test an OutputMultiPath whose initial value is ``[[x]]`` to ensure that
297+
it is returned as ``[x]``, regardless of how accessed."""
298+
select_if = niu.Select(inlist=[[1, 2, 3], [4]], index=1)
299+
select_nd = pe.Node(niu.Select(inlist=[[1, 2, 3], [4]], index=1),
300+
name='select_nd')
301+
302+
ifres = select_if.run()
303+
ndres = select_nd.run()
304+
305+
assert ifres.outputs.out == [4]
306+
assert ndres.outputs.out == [4]
307+
assert select_nd.result.outputs.out == [4]

nipype/pipeline/engine/utils.py

+67-4
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,78 @@ def write_report(node, report_type=None, is_mapnode=False):
233233
return
234234

235235

236+
def _identify_collapses(hastraits):
237+
""" Identify traits that will collapse when being set to themselves.
238+
239+
``OutputMultiObject``s automatically unwrap a list of length 1 to directly
240+
reference the element of that list.
241+
If that element is itself a list of length 1, then the following will
242+
result in modified values.
243+
244+
hastraits.trait_set(**hastraits.trait_get())
245+
246+
Cloning performs this operation on a copy of the original traited object,
247+
allowing us to identify traits that will be affected.
248+
"""
249+
raw = hastraits.trait_get()
250+
cloned = hastraits.clone_traits().trait_get()
251+
252+
collapsed = set()
253+
for key in cloned:
254+
orig = raw[key]
255+
new = cloned[key]
256+
# Allow numpy to handle the equality checks, as mixed lists and arrays
257+
# can be problematic.
258+
if isinstance(orig, list) and len(orig) == 1 and (
259+
not np.array_equal(orig, new) and np.array_equal(orig[0], new)):
260+
collapsed.add(key)
261+
262+
return collapsed
263+
264+
265+
def _uncollapse(indexable, collapsed):
266+
""" Wrap collapsible values in a list to prevent double-collapsing.
267+
268+
Should be used with _identify_collapses to provide the following
269+
idempotent operation:
270+
271+
collapsed = _identify_collapses(hastraits)
272+
hastraits.trait_set(**_uncollapse(hastraits.trait_get(), collapsed))
273+
274+
NOTE: Modifies object in-place, in addition to returning it.
275+
"""
276+
277+
for key in indexable:
278+
if key in collapsed:
279+
indexable[key] = [indexable[key]]
280+
return indexable
281+
282+
283+
def _protect_collapses(hastraits):
284+
""" A collapse-protected replacement for hastraits.trait_get()
285+
286+
May be used as follows to provide an idempotent trait_set:
287+
288+
hastraits.trait_set(**_protect_collapses(hastraits))
289+
"""
290+
collapsed = _identify_collapses(hastraits)
291+
return _uncollapse(hastraits.trait_get(), collapsed)
292+
293+
236294
def save_resultfile(result, cwd, name):
237295
"""Save a result pklz file to ``cwd``"""
238296
resultsfile = os.path.join(cwd, 'result_%s.pklz' % name)
239297
if result.outputs:
240298
try:
241-
outputs = result.outputs.trait_get()
299+
collapsed = _identify_collapses(result.outputs)
300+
outputs = _uncollapse(result.outputs.trait_get(), collapsed)
301+
# Double-protect tosave so that the original, uncollapsed trait
302+
# is saved in the pickle file. Thus, when the loading process
303+
# collapses, the original correct value is loaded.
304+
tosave = _uncollapse(outputs.copy(), collapsed)
242305
except AttributeError:
243-
outputs = result.outputs.dictcopy() # outputs was a bunch
244-
result.outputs.set(**modify_paths(outputs, relative=True, basedir=cwd))
306+
tosave = outputs = result.outputs.dictcopy() # outputs was a bunch
307+
result.outputs.set(**modify_paths(tosave, relative=True, basedir=cwd))
245308

246309
savepkl(resultsfile, result)
247310
logger.debug('saved results in %s', resultsfile)
@@ -293,7 +356,7 @@ def load_resultfile(path, name):
293356
else:
294357
if result.outputs:
295358
try:
296-
outputs = result.outputs.trait_get()
359+
outputs = _protect_collapses(result.outputs)
297360
except AttributeError:
298361
outputs = result.outputs.dictcopy() # outputs == Bunch
299362
try:

0 commit comments

Comments
 (0)