Skip to content

Commit fe0898f

Browse files
Michiel CottaarMichiel Cottaar
authored andcommitted
ENH: raise error if CIFTI-2 header file has different shape as data
1 parent 0282bb9 commit fe0898f

File tree

3 files changed

+80
-18
lines changed

3 files changed

+80
-18
lines changed

nibabel/cifti2/cifti2.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,36 @@ def _to_xml_element(self):
12091209
mat.append(mim._to_xml_element())
12101210
return mat
12111211

1212+
def get_axis(self, index):
1213+
'''
1214+
Generates the Cifti2 axis for a given dimension
1215+
1216+
Parameters
1217+
----------
1218+
index : int
1219+
Dimension for which we want to obtain the mapping.
1220+
1221+
Returns
1222+
-------
1223+
axis : :class:`.cifti2_axes.Axis`
1224+
'''
1225+
from . import cifti2_axes
1226+
return cifti2_axes.from_index_mapping(self.get_index_map(index))
1227+
1228+
def get_data_shape(self):
1229+
"""
1230+
Returns data shape expected based on the CIFTI-2 header
1231+
"""
1232+
from . import cifti2_axes
1233+
if len(self.mapped_indices) == 0:
1234+
return ()
1235+
base_shape = [-1 for _ in range(max(self.mapped_indices) + 1)]
1236+
for mim in self:
1237+
size = len(cifti2_axes.from_index_mapping(mim))
1238+
for idx in mim.applies_to_matrix_dimension:
1239+
base_shape[idx] = size
1240+
return tuple(base_shape)
1241+
12121242

12131243
class Cifti2Header(FileBasedHeader, xml.XmlSerializable):
12141244
''' Class for CIFTI-2 header extension '''
@@ -1279,8 +1309,7 @@ def get_axis(self, index):
12791309
-------
12801310
axis : :class:`.cifti2_axes.Axis`
12811311
'''
1282-
from . import cifti2_axes
1283-
return cifti2_axes.from_index_mapping(self.matrix.get_index_map(index))
1312+
return self.matrix.get_axis(index)
12841313

12851314
@classmethod
12861315
def from_axes(cls, axes):
@@ -1426,6 +1455,10 @@ def to_file_map(self, file_map=None):
14261455
header = self._nifti_header
14271456
extension = Cifti2Extension(content=self.header.to_xml())
14281457
header.extensions.append(extension)
1458+
if header.get_data_shape() != self.header.matrix.get_data_shape():
1459+
raise ValueError("Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
1460+
self._dataobj.shape, self.header.matrix.get_data_shape()
1461+
))
14291462
# if intent code is not set, default to unknown CIFTI
14301463
if header.get_intent()[0] == 'none':
14311464
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
@@ -1438,7 +1471,7 @@ def to_file_map(self, file_map=None):
14381471
img.to_file_map(file_map or self.file_map)
14391472

14401473
def update_headers(self):
1441-
''' Harmonize CIFTI-2 and NIfTI headers with image data
1474+
''' Harmonize NIfTI headers with image data
14421475
14431476
>>> import numpy as np
14441477
>>> data = np.zeros((2,3,4))

nibabel/cifti2/tests/test_cifti2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from nibabel import cifti2 as ci
99
from nibabel.nifti2 import Nifti2Header
10-
from nibabel.cifti2.cifti2 import _float_01, _value_if_klass, Cifti2HeaderError
10+
from nibabel.cifti2.cifti2 import _float_01, _value_if_klass, Cifti2HeaderError, Cifti2NamedMap, Cifti2MatrixIndicesMap
1111

1212
from nose.tools import assert_true, assert_equal, assert_raises, assert_is_none
1313

@@ -358,4 +358,10 @@ class TestCifti2ImageAPI(_TDA):
358358
standard_extension = '.nii'
359359

360360
def make_imaker(self, arr, header=None, ni_header=None):
361+
for idx, sz in enumerate(arr.shape):
362+
maps = [Cifti2NamedMap(str(value)) for value in range(sz)]
363+
mim = ci.Cifti2MatrixIndicesMap(
364+
(idx, ), 'CIFTI_INDEX_TYPE_SCALARS', maps=maps
365+
)
366+
header.matrix.append(mim)
361367
return lambda: self.image_maker(arr.copy(), header, ni_header)

nibabel/cifti2/tests/test_new_cifti2.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
from nibabel import cifti2 as ci
1313
from nibabel.tmpdirs import InTemporaryDirectory
1414

15-
from nose.tools import assert_true, assert_equal
15+
from nose.tools import assert_true, assert_equal, assert_raises
1616

1717
affine = [[-1.5, 0, 0, 90],
1818
[0, 1.5, 0, -85],
19-
[0, 0, 1.5, -71]]
19+
[0, 0, 1.5, -71],
20+
[0, 0, 0, 1.]]
2021

2122
dimensions = (120, 83, 78)
2223

@@ -234,7 +235,7 @@ def test_dtseries():
234235
matrix.append(series_map)
235236
matrix.append(geometry_map)
236237
hdr = ci.Cifti2Header(matrix)
237-
data = np.random.randn(13, 9)
238+
data = np.random.randn(13, 10)
238239
img = ci.Cifti2Image(data, hdr)
239240
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')
240241

@@ -257,7 +258,7 @@ def test_dscalar():
257258
matrix.append(scalar_map)
258259
matrix.append(geometry_map)
259260
hdr = ci.Cifti2Header(matrix)
260-
data = np.random.randn(2, 9)
261+
data = np.random.randn(2, 10)
261262
img = ci.Cifti2Image(data, hdr)
262263
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS')
263264

@@ -279,7 +280,7 @@ def test_dlabel():
279280
matrix.append(label_map)
280281
matrix.append(geometry_map)
281282
hdr = ci.Cifti2Header(matrix)
282-
data = np.random.randn(2, 9)
283+
data = np.random.randn(2, 10)
283284
img = ci.Cifti2Image(data, hdr)
284285
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')
285286

@@ -299,7 +300,7 @@ def test_dconn():
299300
matrix = ci.Cifti2Matrix()
300301
matrix.append(mapping)
301302
hdr = ci.Cifti2Header(matrix)
302-
data = np.random.randn(9, 9)
303+
data = np.random.randn(10, 10)
303304
img = ci.Cifti2Image(data, hdr)
304305
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')
305306

@@ -322,7 +323,7 @@ def test_ptseries():
322323
matrix.append(series_map)
323324
matrix.append(parcel_map)
324325
hdr = ci.Cifti2Header(matrix)
325-
data = np.random.randn(13, 3)
326+
data = np.random.randn(13, 4)
326327
img = ci.Cifti2Image(data, hdr)
327328
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')
328329

@@ -344,7 +345,7 @@ def test_pscalar():
344345
matrix.append(scalar_map)
345346
matrix.append(parcel_map)
346347
hdr = ci.Cifti2Header(matrix)
347-
data = np.random.randn(2, 3)
348+
data = np.random.randn(2, 4)
348349
img = ci.Cifti2Image(data, hdr)
349350
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')
350351

@@ -366,7 +367,7 @@ def test_pdconn():
366367
matrix.append(geometry_map)
367368
matrix.append(parcel_map)
368369
hdr = ci.Cifti2Header(matrix)
369-
data = np.random.randn(2, 3)
370+
data = np.random.randn(10, 4)
370371
img = ci.Cifti2Image(data, hdr)
371372
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')
372373

@@ -388,7 +389,7 @@ def test_dpconn():
388389
matrix.append(parcel_map)
389390
matrix.append(geometry_map)
390391
hdr = ci.Cifti2Header(matrix)
391-
data = np.random.randn(2, 3)
392+
data = np.random.randn(4, 10)
392393
img = ci.Cifti2Image(data, hdr)
393394
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')
394395

@@ -410,7 +411,7 @@ def test_plabel():
410411
matrix.append(label_map)
411412
matrix.append(parcel_map)
412413
hdr = ci.Cifti2Header(matrix)
413-
data = np.random.randn(2, 3)
414+
data = np.random.randn(2, 4)
414415
img = ci.Cifti2Image(data, hdr)
415416

416417
with InTemporaryDirectory():
@@ -429,7 +430,7 @@ def test_pconn():
429430
matrix = ci.Cifti2Matrix()
430431
matrix.append(mapping)
431432
hdr = ci.Cifti2Header(matrix)
432-
data = np.random.randn(3, 3)
433+
data = np.random.randn(4, 4)
433434
img = ci.Cifti2Image(data, hdr)
434435
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')
435436

@@ -453,7 +454,7 @@ def test_pconnseries():
453454
matrix.append(parcel_map)
454455
matrix.append(series_map)
455456
hdr = ci.Cifti2Header(matrix)
456-
data = np.random.randn(3, 3, 13)
457+
data = np.random.randn(4, 4, 13)
457458
img = ci.Cifti2Image(data, hdr)
458459
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
459460
'PARCELLATED_SERIES')
@@ -479,7 +480,7 @@ def test_pconnscalar():
479480
matrix.append(parcel_map)
480481
matrix.append(scalar_map)
481482
hdr = ci.Cifti2Header(matrix)
482-
data = np.random.randn(3, 3, 13)
483+
data = np.random.randn(4, 4, 2)
483484
img = ci.Cifti2Image(data, hdr)
484485
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
485486
'PARCELLATED_SCALAR')
@@ -496,3 +497,25 @@ def test_pconnscalar():
496497
check_parcel_map(img2.header.matrix.get_index_map(0))
497498
check_scalar_map(img2.header.matrix.get_index_map(2))
498499
del img2
500+
501+
502+
def test_wrong_shape():
503+
scalar_map = create_scalar_map((0, ))
504+
brain_model_map = create_geometry_map((1, ))
505+
506+
matrix = ci.Cifti2Matrix()
507+
matrix.append(scalar_map)
508+
matrix.append(brain_model_map)
509+
hdr = ci.Cifti2Header(matrix)
510+
511+
# correct shape is (2, 10)
512+
for data in (
513+
np.random.randn(1, 11),
514+
np.random.randn(2, 10, 1),
515+
np.random.randn(1, 2, 10),
516+
np.random.randn(3, 10),
517+
np.random.randn(2, 9),
518+
):
519+
img = ci.Cifti2Image(data, hdr)
520+
assert_raises(ValueError, img.to_file_map)
521+

0 commit comments

Comments
 (0)