|
| 1 | +"""Base class for Dataset |
| 2 | +""" |
| 3 | +from __future__ import annotations |
| 4 | +from rsatoolbox.util.descriptor_utils import check_descriptor_length_error |
| 5 | +from rsatoolbox.util.descriptor_utils import format_descriptor |
| 6 | +from rsatoolbox.util.descriptor_utils import parse_input_descriptor |
| 7 | +from rsatoolbox.util.file_io import write_dict_hdf5 |
| 8 | +from rsatoolbox.util.file_io import write_dict_pkl |
| 9 | +from rsatoolbox.util.file_io import remove_file |
| 10 | + |
| 11 | + |
| 12 | +class DatasetBase: |
| 13 | + """ |
| 14 | + Abstract dataset class. |
| 15 | + Defines members that every class needs to have, but does not |
| 16 | + implement any interesting behavior. Inherit from this class |
| 17 | + to define specific dataset types |
| 18 | +
|
| 19 | + Args: |
| 20 | + measurements (numpy.ndarray): n_obs x n_channel 2d-array, |
| 21 | + descriptors (dict): descriptors (metadata) |
| 22 | + obs_descriptors (dict): observation descriptors (all |
| 23 | + are array-like with shape = (n_obs,...)) |
| 24 | + channel_descriptors (dict): channel descriptors (all are |
| 25 | + array-like with shape = (n_channel,...)) |
| 26 | +
|
| 27 | + Returns: |
| 28 | + dataset object |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__(self, measurements, descriptors=None, |
| 32 | + obs_descriptors=None, channel_descriptors=None, |
| 33 | + check_dims=True): |
| 34 | + if measurements.ndim != 2: |
| 35 | + raise AttributeError( |
| 36 | + "measurements must be in dimension n_obs x n_channel") |
| 37 | + self.measurements = measurements |
| 38 | + self.n_obs, self.n_channel = self.measurements.shape |
| 39 | + if check_dims: |
| 40 | + check_descriptor_length_error(obs_descriptors, |
| 41 | + "obs_descriptors", |
| 42 | + self.n_obs |
| 43 | + ) |
| 44 | + check_descriptor_length_error(channel_descriptors, |
| 45 | + "channel_descriptors", |
| 46 | + self.n_channel |
| 47 | + ) |
| 48 | + self.descriptors = parse_input_descriptor(descriptors) |
| 49 | + self.obs_descriptors = parse_input_descriptor(obs_descriptors) |
| 50 | + self.channel_descriptors = parse_input_descriptor(channel_descriptors) |
| 51 | + |
| 52 | + def __repr__(self): |
| 53 | + """ |
| 54 | + defines string which is printed for the object |
| 55 | + """ |
| 56 | + return (f'rsatoolbox.data.{self.__class__.__name__}(\n' |
| 57 | + f'measurements = \n{self.measurements}\n' |
| 58 | + f'descriptors = \n{self.descriptors}\n' |
| 59 | + f'obs_descriptors = \n{self.obs_descriptors}\n' |
| 60 | + f'channel_descriptors = \n{self.channel_descriptors}\n' |
| 61 | + ) |
| 62 | + |
| 63 | + def __str__(self): |
| 64 | + """ |
| 65 | + defines the output of print |
| 66 | + """ |
| 67 | + string_desc = format_descriptor(self.descriptors) |
| 68 | + string_obs_desc = format_descriptor(self.obs_descriptors) |
| 69 | + string_channel_desc = format_descriptor(self.channel_descriptors) |
| 70 | + if self.measurements.shape[0] > 5: |
| 71 | + measurements = self.measurements[:5, :] |
| 72 | + else: |
| 73 | + measurements = self.measurements |
| 74 | + return (f'rsatoolbox.data.{self.__class__.__name__}\n' |
| 75 | + f'measurements = \n{measurements}\n...\n\n' |
| 76 | + f'descriptors: \n{string_desc}\n\n' |
| 77 | + f'obs_descriptors: \n{string_obs_desc}\n\n' |
| 78 | + f'channel_descriptors: \n{string_channel_desc}\n' |
| 79 | + ) |
| 80 | + |
| 81 | + def __eq__(self, other: DatasetBase) -> bool: |
| 82 | + """Equality check, to be implemented in the specific |
| 83 | + Dataset class |
| 84 | +
|
| 85 | + Args: |
| 86 | + other (DatasetBase): The object to compare to. |
| 87 | +
|
| 88 | + Raises: |
| 89 | + NotImplementedError: This is not valid if not implemented |
| 90 | + by the specific Dataset class |
| 91 | +
|
| 92 | + Returns: |
| 93 | + bool: Never returns |
| 94 | + """ |
| 95 | + raise NotImplementedError() |
| 96 | + |
| 97 | + def copy(self) -> DatasetBase: |
| 98 | + """Copy Dataset |
| 99 | + To be implemented in child class |
| 100 | +
|
| 101 | + Raises: |
| 102 | + NotImplementedError: raised if not implemented |
| 103 | +
|
| 104 | + Returns: |
| 105 | + DatasetBase: Never returns |
| 106 | + """ |
| 107 | + raise NotImplementedError |
| 108 | + |
| 109 | + def split_obs(self, by): |
| 110 | + """ Returns a list Datasets split by obs |
| 111 | +
|
| 112 | + Args: |
| 113 | + by(String): the descriptor by which the splitting is made |
| 114 | +
|
| 115 | + Returns: |
| 116 | + list of Datasets, splitted by the selected obs_descriptor |
| 117 | +
|
| 118 | + """ |
| 119 | + raise NotImplementedError( |
| 120 | + "split_obs function not implemented in used Dataset class!") |
| 121 | + |
| 122 | + def split_channel(self, by): |
| 123 | + """ Returns a list Datasets split by channels |
| 124 | +
|
| 125 | + Args: |
| 126 | + by(String): the descriptor by which the splitting is made |
| 127 | +
|
| 128 | + Returns: |
| 129 | + list of Datasets, splitted by the selected channel_descriptor |
| 130 | +
|
| 131 | + """ |
| 132 | + raise NotImplementedError( |
| 133 | + "split_channel function not implemented in used Dataset class!") |
| 134 | + |
| 135 | + def subset_obs(self, by, value): |
| 136 | + """ Returns a subsetted Dataset defined by certain obs value |
| 137 | +
|
| 138 | + Args: |
| 139 | + by(String): the descriptor by which the subset selection is made |
| 140 | + from obs dimension |
| 141 | + value: the value by which the subset selection is made |
| 142 | + from obs dimension |
| 143 | +
|
| 144 | + Returns: |
| 145 | + Dataset, with subset defined by the selected obs_descriptor |
| 146 | +
|
| 147 | + """ |
| 148 | + raise NotImplementedError( |
| 149 | + "subset_obs function not implemented in used Dataset class!") |
| 150 | + |
| 151 | + def subset_channel(self, by, value): |
| 152 | + """ Returns a subsetted Dataset defined by certain channel value |
| 153 | +
|
| 154 | + Args: |
| 155 | + by(String): the descriptor by which the subset selection is made |
| 156 | + from channel dimension |
| 157 | + value: the value by which the subset selection is made |
| 158 | + from channel dimension |
| 159 | +
|
| 160 | + Returns: |
| 161 | + Dataset, with subset defined by the selected channel_descriptor |
| 162 | +
|
| 163 | + """ |
| 164 | + raise NotImplementedError( |
| 165 | + "subset_channel function not implemented in used Dataset class!") |
| 166 | + |
| 167 | + def save(self, filename, file_type='hdf5', overwrite=False): |
| 168 | + """ Saves the dataset object to a file |
| 169 | +
|
| 170 | + Args: |
| 171 | + filename(String): path to the file |
| 172 | + [or opened file] |
| 173 | + file_type(String): Type of file to create: |
| 174 | + hdf5: hdf5 file |
| 175 | + pkl: pickle file |
| 176 | + overwrite(Boolean): overwrites file if it already exists |
| 177 | +
|
| 178 | + """ |
| 179 | + data_dict = self.to_dict() |
| 180 | + if overwrite: |
| 181 | + remove_file(filename) |
| 182 | + if file_type == 'hdf5': |
| 183 | + write_dict_hdf5(filename, data_dict) |
| 184 | + elif file_type == 'pkl': |
| 185 | + write_dict_pkl(filename, data_dict) |
| 186 | + |
| 187 | + def to_dict(self): |
| 188 | + """ Generates a dictionary which contains the information to |
| 189 | + recreate the dataset object. Used for saving to disc |
| 190 | +
|
| 191 | + Returns: |
| 192 | + data_dict(dict): dictionary with dataset information |
| 193 | +
|
| 194 | + """ |
| 195 | + data_dict = {} |
| 196 | + data_dict['measurements'] = self.measurements |
| 197 | + data_dict['descriptors'] = self.descriptors |
| 198 | + data_dict['obs_descriptors'] = self.obs_descriptors |
| 199 | + data_dict['channel_descriptors'] = self.channel_descriptors |
| 200 | + data_dict['type'] = type(self).__name__ |
| 201 | + return data_dict |
0 commit comments