|
22 | 22 | #
|
23 | 23 | """Data splitting helpers."""
|
24 | 24 |
|
25 |
| -from pathlib import Path |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +from typing import Any |
26 | 28 |
|
27 |
| -import h5py |
28 | 29 | import numpy as np
|
29 | 30 |
|
| 31 | +from nifreeze.data.base import BaseDataset |
| 32 | + |
30 | 33 |
|
31 |
def lovo_split(dataset: BaseDataset, index: int) -> tuple[Any, Any]:
    """
    Produce one fold of LOVO (leave-one-volume-out).

    Parameters
    ----------
    dataset : :obj:`nifreeze.data.base.BaseDataset`
        Dataset object.
    index : :obj:`int`
        Index of the volume to be left out in this fold.

    Returns
    -------
    :obj:`tuple` of :obj:`tuple`
        A tuple of two elements, the first element being the components
        of the *train* data (including the data themselves and other metadata
        such as gradients for dMRI, or frame times for PET), and the second
        element being the *test* data.

    Raises
    ------
    :obj:`IndexError`
        If *index* is outside the ``[0, len(dataset))`` range.

    """
    n_volumes = len(dataset)

    # Validate explicitly: a negative index would otherwise silently wrap
    # around under NumPy fancy indexing and leave out a volume counted from
    # the end, which is almost certainly a caller bug rather than intent.
    if not 0 <= index < n_volumes:
        raise IndexError(f"Index ({index}) out of range (0..{n_volumes - 1})")

    # Boolean mask selecting only the left-out volume.
    mask = np.zeros(n_volumes, dtype=bool)
    mask[index] = True

    # Train split is everything but the left-out volume; test split is it alone.
    return (dataset[~mask], dataset[mask])
0 commit comments