Skip to content

Commit 76b52f5

Browse files
committed
ENH: Add from_image/from_header methods to bring logic out of tests
1 parent b51ec36 commit 76b52f5

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

Diff for: nibabel/coordimage.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
import nibabel as nib
4+
import nibabel.pointset as ps
35
from nibabel.fileslice import fill_slicer
46

57

@@ -17,6 +19,22 @@ def __init__(self, data, coordaxis, header=None):
1719
self.coordaxis = coordaxis
1820
self.header = header
1921

22+
@classmethod
23+
def from_image(klass, img):
24+
coordaxis = CoordinateAxis.from_header(img.header)
25+
if isinstance(img, nib.Cifti2Image):
26+
if img.ndim != 2:
27+
raise ValueError('Can only interpret 2D images')
28+
for i in img.header.mapped_indices:
29+
if isinstance(img.header.get_axis(i), nib.cifti2.BrainModelAxis):
30+
break
31+
# Reinterpret data ordering based on location of coordinate axis
32+
data = img.dataobj.copy()
33+
data.order = ['F', 'C'][i]
34+
if i == 1:
35+
data._shape = data._shape[::-1]
36+
return klass(data, coordaxis, img.header)
37+
2038

2139
class CoordinateAxis:
2240
"""
@@ -85,6 +103,28 @@ def get_indices(self, parcel, indices=None):
85103
def __len__(self):
86104
return sum(len(parcel) for parcel in self.parcels)
87105

106+
# Hacky factory method for now
107+
@classmethod
108+
def from_header(klass, hdr):
109+
parcels = []
110+
if isinstance(hdr, nib.Cifti2Header):
111+
axes = [hdr.get_axis(i) for i in hdr.mapped_indices]
112+
for ax in axes:
113+
if isinstance(ax, nib.cifti2.BrainModelAxis):
114+
break
115+
else:
116+
raise ValueError('No BrainModelAxis, cannot create CoordinateAxis')
117+
for name, slicer, struct in ax.iter_structures():
118+
if struct.volume_shape:
119+
substruct = ps.NdGrid(struct.volume_shape, struct.affine)
120+
indices = struct.voxel
121+
else:
122+
substruct = None
123+
indices = struct.vertex
124+
parcels.append(Parcel(name, substruct, indices))
125+
126+
return klass(parcels)
127+
88128

89129
class Parcel:
90130
"""

Diff for: nibabel/tests/test_coordimage.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,12 @@ def from_spec(klass, pathlike):
5151

5252
def test_Cifti2Image_as_CoordImage():
5353
ones = nb.load(CIFTI2_DATA / 'ones.dscalar.nii')
54-
axes = [ones.header.get_axis(i) for i in range(ones.ndim)]
55-
56-
parcels = []
57-
for name, slicer, bma in axes[1].iter_structures():
58-
if bma.volume_shape:
59-
substruct = ps.NdGrid(bma.volume_shape, bma.affine)
60-
indices = bma.voxel
61-
else:
62-
substruct = None
63-
indices = bma.vertex
64-
parcels.append(ci.Parcel(name, None, indices))
65-
caxis = ci.CoordinateAxis(parcels)
66-
dobj = ones.dataobj.copy()
67-
dobj.order = 'C' # Hack for image with BMA as the last axis
68-
cimg = ci.CoordinateImage(dobj, caxis, ones.header)
54+
assert ones.shape == (1, 91282)
55+
cimg = ci.CoordinateImage.from_image(ones)
56+
assert cimg.shape == (91282, 1)
6957

58+
caxis = cimg.coordaxis
59+
assert len(caxis) == 91282
7060
assert caxis[...] is caxis
7161
assert caxis[:] is caxis
7262

0 commit comments

Comments
 (0)