Skip to content

Commit 0fefaf3

Browse files
authored
Merge pull request #57 from nipreps/enh/generalize-lovo-splitter
ENH: Update LOVO splitter to new dataset indexed access
2 parents d9466cb + 9327b0b commit 0fefaf3

File tree

5 files changed

+35
-63
lines changed

5 files changed

+35
-63
lines changed

src/nifreeze/data/splitting.py

+16-45
Original file line numberDiff line numberDiff line change
@@ -22,65 +22,36 @@
2222
#
2323
"""Data splitting helpers."""
2424

25-
from pathlib import Path
25+
from __future__ import annotations
26+
27+
from typing import Any
2628

27-
import h5py
2829
import numpy as np
2930

31+
from nifreeze.data.base import BaseDataset
32+
3033

31-
def lovo_split(dataset, index, with_b0=False):
34+
def lovo_split(dataset: BaseDataset, index: int) -> tuple[Any, Any]:
3235
"""
3336
Produce one fold of LOVO (leave-one-volume-out).
3437
3538
Parameters
3639
----------
37-
dataset : :obj:`nifreeze.data.dmri.DWI`
38-
DWI object
40+
dataset : :obj:`nifreeze.data.base.BaseDataset`
41+
Dataset object.
3942
index : :obj:`int`
40-
Index of the DWI orientation to be left out in this fold.
43+
Index of the volume to be left out in this fold.
4144
4245
Returns
4346
-------
44-
(train_data, train_gradients) : :obj:`tuple`
45-
Training DWI and corresponding gradients.
46-
Training data/gradients come **from the updated dataset**.
47-
(test_data, test_gradients) :obj:`tuple`
48-
Test 3D map (one DWI orientation) and corresponding b-vector/value.
49-
The test data/gradient come **from the original dataset**.
47+
:obj:`tuple` of :obj:`tuple`
48+
A tuple of two elements, the first element being the components
49+
of the *train* data (including the data themselves and other metadata
50+
such as gradients for dMRI, or frame times for PET), and the second
51+
element being the *test* data.
5052
5153
"""
52-
53-
if not Path(dataset.get_filename()).exists():
54-
dataset.to_filename(dataset.get_filename())
55-
56-
# read original DWI data & b-vector
57-
with h5py.File(dataset.get_filename(), "r") as in_file:
58-
root = in_file["/0"]
59-
data = np.asanyarray(root["dataobj"])
60-
gradients = np.asanyarray(root["gradients"])
61-
62-
# if the size of the mask does not match data, cache is stale
63-
mask = np.zeros(data.shape[-1], dtype=bool)
54+
mask = np.zeros(len(dataset), dtype=bool)
6455
mask[index] = True
6556

66-
train_data = data[..., ~mask]
67-
train_gradients = gradients[..., ~mask]
68-
test_data = data[..., mask]
69-
test_gradients = gradients[..., mask]
70-
71-
if with_b0:
72-
train_data = np.concatenate(
73-
(np.asanyarray(dataset.bzero)[..., np.newaxis], train_data),
74-
axis=-1,
75-
)
76-
b0vec = np.zeros((4, 1))
77-
b0vec[0, 0] = 1
78-
train_gradients = np.concatenate(
79-
(b0vec, train_gradients),
80-
axis=-1,
81-
)
82-
83-
return (
84-
(train_data, train_gradients),
85-
(test_data, test_gradients),
86-
)
57+
return (dataset[~mask], dataset[mask])

src/nifreeze/estimator.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ def estimate(
141141
pbar.set_description_str(
142142
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
143143
)
144-
data_train, data_test = lovo_split(data, i, with_b0=True)
145-
grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
144+
data_train, data_test = lovo_split(data, i)
145+
grad_str = f"{i}, {data_test[-1][:3]}, b={int(data_test[-1][3])}"
146146
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")
147147

148148
if not single_model: # A true LOGO estimator
149149
if hasattr(data, "gradients"):
150-
kwargs["gtab"] = data_train[1]
150+
kwargs["gtab"] = data_train[-1]
151151
# Factory creates the appropriate model and pipes arguments
152152
dwmodel = ModelFactory.init(
153153
model=model,
@@ -162,7 +162,7 @@ def estimate(
162162
)
163163

164164
# generate a synthetic dw volume for the test gradient
165-
predicted = dwmodel.predict(data_test[1])
165+
predicted = dwmodel.predict(data_test[-1])
166166

167167
# prepare data for running ANTs
168168
fixed, moving = _prepare_registration_data(
@@ -180,7 +180,7 @@ def estimate(
180180
data.motion_affines,
181181
data.affine,
182182
data.dataobj.shape[:3],
183-
data_test[1][3],
183+
data_test[-1][3],
184184
i_iter,
185185
i,
186186
ptmp_dir,

test/test_integration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,5 @@ def test_proximity_estimator_trivial_model(datadir, tmp_path):
9696
nt.linear.Affine(est),
9797
xfms[i],
9898
).max()
99-
< 0.2
99+
< 0.25
100100
)

test/test_model.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,19 @@ def test_two_initialisations(datadir):
154154

155155
# Direct initialisation
156156
model1 = model.AverageDWIModel(
157-
gtab=data_train[1],
157+
gtab=data_train[-1],
158158
S0=dmri_dataset.bzero,
159159
th_low=100,
160160
th_high=1000,
161161
bias=False,
162162
stat="mean",
163163
)
164-
model1.fit(data_train[0], gtab=data_train[1])
165-
predicted1 = model1.predict(data_test[1])
164+
model1.fit(data_train[0], gtab=data_train[-1])
165+
predicted1 = model1.predict(data_test[-1])
166166

167167
# Initialisation via ModelFactory
168168
model2 = model.ModelFactory.init(
169-
gtab=data_train[1],
169+
gtab=data_train[-1],
170170
model="avgdwi",
171171
S0=dmri_dataset.bzero,
172172
th_low=100,
@@ -176,9 +176,9 @@ def test_two_initialisations(datadir):
176176
)
177177

178178
with pytest.raises(ModelNotFittedError):
179-
model2.predict(data_test[1])
179+
model2.predict(data_test[-1])
180180

181-
model2.fit(data_train[0], gtab=data_train[1])
182-
predicted2 = model2.predict(data_test[1])
181+
model2.fit(data_train[0], gtab=data_train[-1])
182+
predicted2 = model2.predict(data_test[-1])
183183

184184
assert np.all(predicted1 == predicted2)

test/test_splitting.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def test_lovo_split(datadir):
3737
3838
Returns:
3939
None
40+
4041
"""
4142
data = DWI.from_filename(datadir / "dwi.h5")
4243

@@ -52,11 +53,11 @@ def test_lovo_split(datadir):
5253
data.gradients[..., index] = 1
5354

5455
# Apply the lovo_split function at the specified index
55-
(train_data, train_gradients), (test_data, test_gradients) = lovo_split(data, index)
56+
train_data, test_data = lovo_split(data, index)
5657

5758
# Check if the test data contains only 1s
5859
# and the train data contains only 0s after the split
59-
assert np.all(test_data == 1)
60-
assert np.all(train_data == 0)
61-
assert np.all(test_gradients == 1)
62-
assert np.all(train_gradients == 0)
60+
assert np.all(test_data[0] == 1)
61+
assert np.all(train_data[0] == 0)
62+
assert np.all(test_data[-1] == 1)
63+
assert np.all(train_data[-1] == 0)

0 commit comments

Comments
 (0)